Skip to main content

limbo_sqlite3_parser/to_sql_string/stmt/select.rs

1use std::fmt::Display;
2
3use crate::{
4    ast::{self, fmt::ToTokens},
5    to_sql_string::{ToSqlContext, ToSqlString},
6};
7
8impl ToSqlString for ast::Select {
9    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
10        let mut ret = Vec::new();
11        if let Some(with) = &self.with {
12            ret.push(with.to_sql_string(context));
13        }
14
15        ret.push(self.body.to_sql_string(context));
16
17        if let Some(order_by) = &self.order_by {
18            // TODO: SortedColumn missing collation in ast
19            let joined_cols = order_by
20                .iter()
21                .map(|col| col.to_sql_string(context))
22                .collect::<Vec<_>>()
23                .join(", ");
24            ret.push(format!("ORDER BY {}", joined_cols));
25        }
26        if let Some(limit) = &self.limit {
27            ret.push(limit.to_sql_string(context));
28        }
29        ret.join(" ")
30    }
31}
32
33impl ToSqlString for ast::SelectBody {
34    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
35        let mut ret = self.select.to_sql_string(context);
36
37        if let Some(compounds) = &self.compounds {
38            ret.push(' ');
39            let compound_selects = compounds
40                .iter()
41                .map(|compound_select| {
42                    let mut curr = compound_select.operator.to_string();
43                    curr.push(' ');
44                    curr.push_str(&compound_select.select.to_sql_string(context));
45                    curr
46                })
47                .collect::<Vec<_>>()
48                .join(" ");
49            ret.push_str(&compound_selects);
50        }
51        ret
52    }
53}
54
55impl ToSqlString for ast::OneSelect {
56    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
57        match self {
58            ast::OneSelect::Select(select) => select.to_sql_string(context),
59            ast::OneSelect::Values(values) => {
60                let joined_values = values
61                    .iter()
62                    .map(|value| {
63                        let joined_value = value
64                            .iter()
65                            .map(|e| e.to_sql_string(context))
66                            .collect::<Vec<_>>()
67                            .join(", ");
68                        format!("({})", joined_value)
69                    })
70                    .collect::<Vec<_>>()
71                    .join(", ");
72                format!("VALUES {}", joined_values)
73            }
74        }
75    }
76}
77
78impl ToSqlString for ast::SelectInner {
79    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
80        let mut ret = Vec::with_capacity(2 + self.columns.len());
81        ret.push("SELECT".to_string());
82        if let Some(distinct) = self.distinctness {
83            ret.push(distinct.to_string());
84        }
85        let joined_cols = self
86            .columns
87            .iter()
88            .map(|col| col.to_sql_string(context))
89            .collect::<Vec<_>>()
90            .join(", ");
91        ret.push(joined_cols);
92
93        if let Some(from) = &self.from {
94            ret.push(from.to_sql_string(context));
95        }
96        if let Some(where_expr) = &self.where_clause {
97            ret.push("WHERE".to_string());
98            ret.push(where_expr.to_sql_string(context));
99        }
100        if let Some(group_by) = &self.group_by {
101            ret.push(group_by.to_sql_string(context));
102        }
103        if let Some(window_clause) = &self.window_clause {
104            ret.push("WINDOW".to_string());
105            let joined_window = window_clause
106                .iter()
107                .map(|window_def| window_def.to_sql_string(context))
108                .collect::<Vec<_>>()
109                .join(",");
110            ret.push(joined_window);
111        }
112
113        ret.join(" ")
114    }
115}
116
117impl ToSqlString for ast::FromClause {
118    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
119        let mut ret = String::from("FROM");
120        if let Some(select_table) = &self.select {
121            ret.push(' ');
122            ret.push_str(&select_table.to_sql_string(context));
123        }
124        if let Some(joins) = &self.joins {
125            ret.push(' ');
126            let joined_joins = joins
127                .iter()
128                .map(|join| {
129                    let mut curr = join.operator.to_string();
130                    curr.push(' ');
131                    curr.push_str(&join.table.to_sql_string(context));
132                    if let Some(join_constraint) = &join.constraint {
133                        curr.push(' ');
134                        curr.push_str(&join_constraint.to_sql_string(context));
135                    }
136                    curr
137                })
138                .collect::<Vec<_>>()
139                .join(" ");
140            ret.push_str(&joined_joins);
141        }
142        ret
143    }
144}
145
146impl ToSqlString for ast::SelectTable {
147    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
148        let mut ret = String::new();
149        match self {
150            Self::Table(name, alias, indexed) => {
151                ret.push_str(&name.to_sql_string(context));
152                if let Some(alias) = alias {
153                    ret.push(' ');
154                    ret.push_str(&alias.to_string());
155                }
156                if let Some(indexed) = indexed {
157                    ret.push(' ');
158                    ret.push_str(&indexed.to_string());
159                }
160            }
161            Self::TableCall(table_func, args, alias) => {
162                ret.push_str(&table_func.to_sql_string(context));
163                if let Some(args) = args {
164                    ret.push(' ');
165                    let joined_args = args
166                        .iter()
167                        .map(|arg| arg.to_sql_string(context))
168                        .collect::<Vec<_>>()
169                        .join(", ");
170                    ret.push_str(&joined_args);
171                }
172                if let Some(alias) = alias {
173                    ret.push(' ');
174                    ret.push_str(&alias.to_string());
175                }
176            }
177            Self::Select(select, alias) => {
178                ret.push('(');
179                ret.push_str(&select.to_sql_string(context));
180                ret.push(')');
181                if let Some(alias) = alias {
182                    ret.push(' ');
183                    ret.push_str(&alias.to_string());
184                }
185            }
186            Self::Sub(from_clause, alias) => {
187                ret.push('(');
188                ret.push_str(&from_clause.to_sql_string(context));
189                ret.push(')');
190                if let Some(alias) = alias {
191                    ret.push(' ');
192                    ret.push_str(&alias.to_string());
193                }
194            }
195        }
196        ret
197    }
198}
199
200impl ToSqlString for ast::With {
201    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
202        format!(
203            "WITH{} {}",
204            if self.recursive { " RECURSIVE " } else { "" },
205            self.ctes
206                .iter()
207                .map(|cte| cte.to_sql_string(context))
208                .collect::<Vec<_>>()
209                .join(", ")
210        )
211    }
212}
213
214impl ToSqlString for ast::Limit {
215    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
216        format!(
217            "LIMIT {}{}",
218            self.expr.to_sql_string(context),
219            self.offset
220                .as_ref()
221                .map_or("".to_string(), |offset| format!(
222                    " OFFSET {}",
223                    offset.to_sql_string(context)
224                ))
225        )
226        // TODO: missing , + expr in ast
227    }
228}
229
230impl ToSqlString for ast::CommonTableExpr {
231    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
232        let mut ret = Vec::with_capacity(self.columns.as_ref().map_or(2, |cols| cols.len()));
233        ret.push(self.tbl_name.0.clone());
234        if let Some(cols) = &self.columns {
235            let joined_cols = cols
236                .iter()
237                .map(|col| col.to_string())
238                .collect::<Vec<_>>()
239                .join(", ");
240
241            ret.push(format!("({})", joined_cols));
242        }
243        ret.push(format!(
244            "AS {}({})",
245            {
246                let mut materialized = self.materialized.to_string();
247                if !materialized.is_empty() {
248                    materialized.push(' ');
249                }
250                materialized
251            },
252            self.select.to_sql_string(context)
253        ));
254        ret.join(" ")
255    }
256}
257
258impl Display for ast::IndexedColumn {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        write!(f, "{}", self.col_name.0)
261    }
262}
263
264impl ToSqlString for ast::SortedColumn {
265    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
266        let mut curr = self.expr.to_sql_string(context);
267        if let Some(sort_order) = self.order {
268            curr.push(' ');
269            curr.push_str(&sort_order.to_string());
270        }
271        if let Some(nulls_order) = self.nulls {
272            curr.push(' ');
273            curr.push_str(&nulls_order.to_string());
274        }
275        curr
276    }
277}
278
279impl Display for ast::SortOrder {
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        self.to_fmt(f)
282    }
283}
284
285impl Display for ast::NullsOrder {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        self.to_fmt(f)
288    }
289}
290
291impl Display for ast::Materialized {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        let value = match self {
294            Self::Any => "",
295            Self::No => "NOT MATERIALIZED",
296            Self::Yes => "MATERIALIZED",
297        };
298        write!(f, "{}", value)
299    }
300}
301
302impl ToSqlString for ast::ResultColumn {
303    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
304        let mut ret = String::new();
305        match self {
306            Self::Expr(expr, alias) => {
307                ret.push_str(&expr.to_sql_string(context));
308                if let Some(alias) = alias {
309                    ret.push(' ');
310                    ret.push_str(&alias.to_string());
311                }
312            }
313            Self::Star => {
314                ret.push('*');
315            }
316            Self::TableStar(name) => {
317                ret.push_str(&format!("{}.*", name.0));
318            }
319        }
320        ret
321    }
322}
323
324impl Display for ast::As {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(
327            f,
328            "{}",
329            match self {
330                Self::As(alias) => {
331                    format!("AS {}", alias.0)
332                }
333                Self::Elided(alias) => alias.0.clone(),
334            }
335        )
336    }
337}
338
339impl Display for ast::Indexed {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(
342            f,
343            "{}",
344            match self {
345                Self::NotIndexed => "NOT INDEXED".to_string(),
346                Self::IndexedBy(name) => format!("INDEXED BY {}", name.0),
347            }
348        )
349    }
350}
351
352impl Display for ast::JoinOperator {
353    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354        write!(
355            f,
356            "{}",
357            match self {
358                Self::Comma => ",".to_string(),
359                Self::TypedJoin(join) => {
360                    let join_keyword = "JOIN";
361                    if let Some(join) = join {
362                        format!("{} {}", join, join_keyword)
363                    } else {
364                        join_keyword.to_string()
365                    }
366                }
367            }
368        )
369    }
370}
371
372impl Display for ast::JoinType {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        let value = {
375            let mut modifiers = Vec::new();
376            if self.contains(Self::NATURAL) {
377                modifiers.push("NATURAL");
378            }
379            if self.contains(Self::LEFT) || self.contains(Self::RIGHT) {
380                // TODO: I think the parser incorrectly asigns outer to every LEFT and RIGHT query
381                if self.contains(Self::LEFT | Self::RIGHT) {
382                    modifiers.push("FULL");
383                } else if self.contains(Self::LEFT) {
384                    modifiers.push("LEFT");
385                } else if self.contains(Self::RIGHT) {
386                    modifiers.push("RIGHT");
387                }
388                // FIXME: ignore outer joins as I think they are parsed incorrectly in the bitflags
389                // if self.contains(Self::OUTER) {
390                //     modifiers.push("OUTER");
391                // }
392            }
393
394            if self.contains(Self::INNER) {
395                modifiers.push("INNER");
396            }
397            if self.contains(Self::CROSS) {
398                modifiers.push("CROSS");
399            }
400            modifiers.join(" ")
401        };
402        write!(f, "{}", value)
403    }
404}
405
406impl ToSqlString for ast::JoinConstraint {
407    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
408        match self {
409            Self::On(expr) => {
410                format!("ON {}", expr.to_sql_string(context))
411            }
412            Self::Using(col_names) => {
413                let joined_names = col_names
414                    .iter()
415                    .map(|col| col.0.clone())
416                    .collect::<Vec<_>>()
417                    .join(",");
418                format!("USING ({})", joined_names)
419            }
420        }
421    }
422}
423
424impl ToSqlString for ast::GroupBy {
425    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
426        let mut ret = String::from("GROUP BY ");
427        let curr = self
428            .exprs
429            .iter()
430            .map(|expr| expr.to_sql_string(context))
431            .collect::<Vec<_>>()
432            .join(",");
433        ret.push_str(&curr);
434        if let Some(having) = &self.having {
435            ret.push_str(&format!(" HAVING {}", having.to_sql_string(context)));
436        }
437        ret
438    }
439}
440
441impl ToSqlString for ast::WindowDef {
442    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
443        format!("{} AS {}", self.name.0, self.window.to_sql_string(context))
444    }
445}
446
447impl ToSqlString for ast::Window {
448    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
449        let mut ret = Vec::new();
450        if let Some(name) = &self.base {
451            ret.push(name.0.clone());
452        }
453        if let Some(partition) = &self.partition_by {
454            let joined_exprs = partition
455                .iter()
456                .map(|e| e.to_sql_string(context))
457                .collect::<Vec<_>>()
458                .join(",");
459            ret.push(format!("PARTITION BY {}", joined_exprs));
460        }
461        if let Some(order_by) = &self.order_by {
462            let joined_cols = order_by
463                .iter()
464                .map(|col| col.to_sql_string(context))
465                .collect::<Vec<_>>()
466                .join(", ");
467            ret.push(format!("ORDER BY {}", joined_cols));
468        }
469        if let Some(frame_claue) = &self.frame_clause {
470            ret.push(frame_claue.to_sql_string(context));
471        }
472        format!("({})", ret.join(" "))
473    }
474}
475
476impl ToSqlString for ast::FrameClause {
477    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
478        let mut ret = Vec::new();
479        ret.push(self.mode.to_string());
480        let start_sql = self.start.to_sql_string(context);
481        if let Some(end) = &self.end {
482            ret.push(format!(
483                "BETWEEN {} AND {}",
484                start_sql,
485                end.to_sql_string(context)
486            ));
487        } else {
488            ret.push(start_sql);
489        }
490        if let Some(exclude) = &self.exclude {
491            ret.push(exclude.to_string());
492        }
493
494        ret.join(" ")
495    }
496}
497
498impl Display for ast::FrameMode {
499    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500        self.to_fmt(f)
501    }
502}
503
504impl ToSqlString for ast::FrameBound {
505    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
506        match self {
507            Self::CurrentRow => "CURRENT ROW".to_string(),
508            Self::Following(expr) => format!("{} FOLLOWING", expr.to_sql_string(context)),
509            Self::Preceding(expr) => format!("{} PRECEDING", expr.to_sql_string(context)),
510            Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
511            Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
512        }
513    }
514}
515
516impl Display for ast::FrameExclude {
517    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518        write!(f, "{}", {
519            let clause = match self {
520                Self::CurrentRow => "CURRENT ROW",
521                Self::Group => "GROUP",
522                Self::NoOthers => "NO OTHERS",
523                Self::Ties => "TIES",
524            };
525            format!("EXCLUDE {}", clause)
526        })
527    }
528}
529
#[cfg(test)]
mod tests {
    //! Rendering cases for `SELECT` statements, one `to_sql_string_test!` per statement.
    //!
    //! NOTE(review): `to_sql_string_test!` is defined elsewhere in the crate; it
    //! presumably parses each input and compares the re-rendered SQL against it.
    //! Confirm there.
    use crate::to_sql_string_test;

    // ---- Projections and filters ------------------------------------------------

    to_sql_string_test!(test_select_basic, "SELECT 1;");

    to_sql_string_test!(test_select_table, "SELECT * FROM t;");

    to_sql_string_test!(test_select_table_2, "SELECT a FROM t;");

    to_sql_string_test!(test_select_multiple_columns, "SELECT a, b, c FROM t;");

    to_sql_string_test!(test_select_with_alias, "SELECT a AS col1 FROM t;");

    to_sql_string_test!(test_select_with_table_alias, "SELECT t1.a FROM t AS t1;");

    to_sql_string_test!(test_select_with_distinct, "SELECT DISTINCT a FROM t;");

    to_sql_string_test!(test_select_with_function, "SELECT COUNT(a) FROM t;");

    to_sql_string_test!(
        test_select_with_case,
        "SELECT CASE WHEN a > 0 THEN 'positive' ELSE 'non-positive' END AS result FROM t;"
    );

    to_sql_string_test!(test_select_with_where, "SELECT a FROM t WHERE b = 1;");

    to_sql_string_test!(
        test_select_with_multiple_conditions,
        "SELECT a FROM t WHERE b = 1 AND c > 2;"
    );

    to_sql_string_test!(
        test_select_with_complex_where,
        "SELECT a FROM t WHERE b IN (1, 2, 3) AND c BETWEEN 10 AND 20 OR d IS NULL;"
    );

    // ---- Ordering and limits ----------------------------------------------------

    to_sql_string_test!(
        test_select_with_order_by,
        "SELECT a FROM t ORDER BY a DESC;"
    );

    to_sql_string_test!(
        test_select_with_complex_order_by,
        "SELECT a, b FROM t ORDER BY CASE WHEN a IS NULL THEN 1 ELSE 0 END, b ASC, c DESC;"
    );

    to_sql_string_test!(test_select_with_limit, "SELECT a FROM t LIMIT 10;");

    to_sql_string_test!(
        test_select_with_offset,
        "SELECT a FROM t LIMIT 10 OFFSET 5;"
    );

    // ---- Joins ------------------------------------------------------------------

    to_sql_string_test!(
        test_select_with_join,
        "SELECT a FROM t JOIN t2 ON t.b = t2.b;"
    );

    to_sql_string_test!(
        test_select_multiple_joins,
        "SELECT t1.a, t2.b, t3.c FROM t1 JOIN t2 ON t1.id = t2.id LEFT JOIN t3 ON t2.id = t3.id;"
    );

    to_sql_string_test!(
        test_select_with_full_outer_join,
        "SELECT t1.a, t2.b FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id;",
        ignore = "OUTER JOIN is incorrectly parsed in parser"
    );

    // ---- Aggregation ------------------------------------------------------------

    to_sql_string_test!(
        test_select_with_group_by,
        "SELECT a, COUNT(*) FROM t GROUP BY a;"
    );

    to_sql_string_test!(
        test_select_with_having,
        "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1;"
    );

    to_sql_string_test!(
        test_select_with_aggregate_and_join,
        "SELECT t1.a, COUNT(t2.b) FROM t1 LEFT JOIN t2 ON t1.id = t2.id GROUP BY t1.a HAVING COUNT(t2.b) > 5;"
    );

    // ---- Subqueries and common table expressions --------------------------------

    to_sql_string_test!(
        test_select_with_subquery,
        "SELECT a FROM (SELECT b FROM t) AS sub;"
    );

    to_sql_string_test!(
        test_select_nested_subquery,
        "SELECT a FROM (SELECT b FROM (SELECT c FROM t WHERE c > 10) AS sub1 WHERE b < 20) AS sub2;"
    );

    to_sql_string_test!(
        test_select_with_exists,
        "SELECT a FROM t WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b = t.a);"
    );

    to_sql_string_test!(
        test_select_with_correlated_subquery,
        "SELECT a, (SELECT COUNT(*) FROM t2 WHERE t2.b = t.a) AS count_b FROM t;"
    );

    to_sql_string_test!(
        test_select_with_cte,
        "WITH cte AS (SELECT a FROM t WHERE b = 1) SELECT a FROM cte WHERE a > 10;"
    );

    to_sql_string_test!(
        test_select_with_multiple_ctes,
        "WITH cte1 AS (SELECT a FROM t WHERE b = 1), cte2 AS (SELECT c FROM t2 WHERE d = 2) SELECT cte1.a, cte2.c FROM cte1 JOIN cte2 ON cte1.a = cte2.c;"
    );

    // ---- Compound selects -------------------------------------------------------

    to_sql_string_test!(
        test_select_with_union,
        "SELECT a FROM t1 UNION SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_union_all,
        "SELECT a FROM t1 UNION ALL SELECT b FROM t2;"
    );

    // ---- Window functions -------------------------------------------------------

    to_sql_string_test!(
        test_select_with_window_function,
        "SELECT a, ROW_NUMBER() OVER (PARTITION BY b ORDER BY c DESC) AS rn FROM t;"
    );

    to_sql_string_test!(
        test_select_with_aggregate_window,
        "SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS running_sum FROM t;"
    );

    to_sql_string_test!(
        test_select_with_exclude,
        "SELECT 
    c.name,
    o.order_id,
    o.order_amount,
    SUM(o.order_amount) OVER (PARTITION BY c.id
        ORDER BY o.order_date
        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
        EXCLUDE CURRENT ROW) AS running_total_excluding_current
FROM customers c
JOIN orders o ON c.id = o.customer_id
WHERE EXISTS (SELECT 1
    FROM orders o2
    WHERE o2.customer_id = c.id
    AND o2.order_amount > 1000);"
    );
}