limbo_sqlite3_parser/to_sql_string/stmt/select.rs

1use std::fmt::Display;
2
3use crate::{
4    ast::{self, fmt::ToTokens},
5    to_sql_string::{ToSqlContext, ToSqlString},
6};
7
8impl ToSqlString for ast::Select {
9    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
10        let mut ret = Vec::new();
11        if let Some(with) = &self.with {
12            ret.push(with.to_sql_string(context));
13        }
14
15        ret.push(self.body.to_sql_string(context));
16
17        if let Some(order_by) = &self.order_by {
18            // TODO: SortedColumn missing collation in ast
19            let joined_cols = order_by
20                .iter()
21                .map(|col| col.to_sql_string(context))
22                .collect::<Vec<_>>()
23                .join(", ");
24            ret.push(format!("ORDER BY {}", joined_cols));
25        }
26        if let Some(limit) = &self.limit {
27            ret.push(limit.to_sql_string(context));
28        }
29        ret.join(" ")
30    }
31}
32
33impl ToSqlString for ast::SelectBody {
34    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
35        let mut ret = self.select.to_sql_string(context);
36
37        if let Some(compounds) = &self.compounds {
38            ret.push(' ');
39            let compound_selects = compounds
40                .iter()
41                .map(|compound_select| {
42                    let mut curr = compound_select.operator.to_string();
43                    curr.push(' ');
44                    curr.push_str(&compound_select.select.to_sql_string(context));
45                    curr
46                })
47                .collect::<Vec<_>>()
48                .join(" ");
49            ret.push_str(&compound_selects);
50        }
51        ret
52    }
53}
54
55impl ToSqlString for ast::OneSelect {
56    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
57        match self {
58            ast::OneSelect::Select(select) => select.to_sql_string(context),
59            ast::OneSelect::Values(values) => {
60                let joined_values = values
61                    .iter()
62                    .map(|value| {
63                        let joined_value = value
64                            .iter()
65                            .map(|e| e.to_sql_string(context))
66                            .collect::<Vec<_>>()
67                            .join(", ");
68                        format!("({})", joined_value)
69                    })
70                    .collect::<Vec<_>>()
71                    .join(", ");
72                format!("VALUES {}", joined_values)
73            }
74        }
75    }
76}
77
78impl ToSqlString for ast::SelectInner {
79    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
80        dbg!(&self);
81        let mut ret = Vec::with_capacity(2 + self.columns.len());
82        ret.push("SELECT".to_string());
83        if let Some(distinct) = self.distinctness {
84            ret.push(distinct.to_string());
85        }
86        let joined_cols = self
87            .columns
88            .iter()
89            .map(|col| col.to_sql_string(context))
90            .collect::<Vec<_>>()
91            .join(", ");
92        ret.push(joined_cols);
93
94        if let Some(from) = &self.from {
95            ret.push(from.to_sql_string(context));
96        }
97        if let Some(where_expr) = &self.where_clause {
98            ret.push("WHERE".to_string());
99            ret.push(where_expr.to_sql_string(context));
100        }
101        if let Some(group_by) = &self.group_by {
102            ret.push(group_by.to_sql_string(context));
103        }
104        if let Some(window_clause) = &self.window_clause {
105            ret.push("WINDOW".to_string());
106            let joined_window = window_clause
107                .iter()
108                .map(|window_def| window_def.to_sql_string(context))
109                .collect::<Vec<_>>()
110                .join(",");
111            ret.push(joined_window);
112        }
113
114        ret.join(" ")
115    }
116}
117
118impl ToSqlString for ast::FromClause {
119    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
120        let mut ret = String::from("FROM");
121        if let Some(select_table) = &self.select {
122            ret.push(' ');
123            ret.push_str(&select_table.to_sql_string(context));
124        }
125        if let Some(joins) = &self.joins {
126            ret.push(' ');
127            let joined_joins = joins
128                .iter()
129                .map(|join| {
130                    let mut curr = join.operator.to_string();
131                    curr.push(' ');
132                    curr.push_str(&join.table.to_sql_string(context));
133                    if let Some(join_constraint) = &join.constraint {
134                        curr.push(' ');
135                        curr.push_str(&join_constraint.to_sql_string(context));
136                    }
137                    curr
138                })
139                .collect::<Vec<_>>()
140                .join(" ");
141            ret.push_str(&joined_joins);
142        }
143        ret
144    }
145}
146
147impl ToSqlString for ast::SelectTable {
148    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
149        let mut ret = String::new();
150        match self {
151            Self::Table(name, alias, indexed) => {
152                ret.push_str(&name.to_sql_string(context));
153                if let Some(alias) = alias {
154                    ret.push(' ');
155                    ret.push_str(&alias.to_string());
156                }
157                if let Some(indexed) = indexed {
158                    ret.push(' ');
159                    ret.push_str(&indexed.to_string());
160                }
161            }
162            Self::TableCall(table_func, args, alias) => {
163                ret.push_str(&table_func.to_sql_string(context));
164                if let Some(args) = args {
165                    ret.push(' ');
166                    let joined_args = args
167                        .iter()
168                        .map(|arg| arg.to_sql_string(context))
169                        .collect::<Vec<_>>()
170                        .join(", ");
171                    ret.push_str(&joined_args);
172                }
173                if let Some(alias) = alias {
174                    ret.push(' ');
175                    ret.push_str(&alias.to_string());
176                }
177            }
178            Self::Select(select, alias) => {
179                ret.push('(');
180                ret.push_str(&select.to_sql_string(context));
181                ret.push(')');
182                if let Some(alias) = alias {
183                    ret.push(' ');
184                    ret.push_str(&alias.to_string());
185                }
186            }
187            Self::Sub(from_clause, alias) => {
188                ret.push('(');
189                ret.push_str(&from_clause.to_sql_string(context));
190                ret.push(')');
191                if let Some(alias) = alias {
192                    ret.push(' ');
193                    ret.push_str(&alias.to_string());
194                }
195            }
196        }
197        ret
198    }
199}
200
201impl ToSqlString for ast::With {
202    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
203        format!(
204            "WITH{} {}",
205            if self.recursive { " RECURSIVE " } else { "" },
206            self.ctes
207                .iter()
208                .map(|cte| cte.to_sql_string(context))
209                .collect::<Vec<_>>()
210                .join(", ")
211        )
212    }
213}
214
215impl ToSqlString for ast::Limit {
216    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
217        format!(
218            "LIMIT {}{}",
219            self.expr.to_sql_string(context),
220            self.offset
221                .as_ref()
222                .map_or("".to_string(), |offset| format!(
223                    " OFFSET {}",
224                    offset.to_sql_string(context)
225                ))
226        )
227        // TODO: missing , + expr in ast
228    }
229}
230
231impl ToSqlString for ast::CommonTableExpr {
232    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
233        let mut ret = Vec::with_capacity(self.columns.as_ref().map_or(2, |cols| cols.len()));
234        ret.push(self.tbl_name.0.clone());
235        if let Some(cols) = &self.columns {
236            let joined_cols = cols
237                .iter()
238                .map(|col| col.to_string())
239                .collect::<Vec<_>>()
240                .join(", ");
241
242            ret.push(format!("({})", joined_cols));
243        }
244        ret.push(format!(
245            "AS {}({})",
246            {
247                let mut materialized = self.materialized.to_string();
248                if !materialized.is_empty() {
249                    materialized.push(' ');
250                }
251                materialized
252            },
253            self.select.to_sql_string(context)
254        ));
255        ret.join(" ")
256    }
257}
258
259impl Display for ast::IndexedColumn {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        write!(f, "{}", self.col_name.0)
262    }
263}
264
265impl ToSqlString for ast::SortedColumn {
266    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
267        let mut curr = self.expr.to_sql_string(context);
268        if let Some(sort_order) = self.order {
269            curr.push(' ');
270            curr.push_str(&sort_order.to_string());
271        }
272        if let Some(nulls_order) = self.nulls {
273            curr.push(' ');
274            curr.push_str(&nulls_order.to_string());
275        }
276        curr
277    }
278}
279
280impl Display for ast::SortOrder {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        self.to_fmt(f)
283    }
284}
285
286impl Display for ast::NullsOrder {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        self.to_fmt(f)
289    }
290}
291
292impl Display for ast::Materialized {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        let value = match self {
295            Self::Any => "",
296            Self::No => "NOT MATERIALIZED",
297            Self::Yes => "MATERIALIZED",
298        };
299        write!(f, "{}", value)
300    }
301}
302
303impl ToSqlString for ast::ResultColumn {
304    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
305        let mut ret = String::new();
306        match self {
307            Self::Expr(expr, alias) => {
308                ret.push_str(&expr.to_sql_string(context));
309                if let Some(alias) = alias {
310                    ret.push(' ');
311                    ret.push_str(&alias.to_string());
312                }
313            }
314            Self::Star => {
315                ret.push('*');
316            }
317            Self::TableStar(name) => {
318                ret.push_str(&format!("{}.*", name.0));
319            }
320        }
321        ret
322    }
323}
324
325impl Display for ast::As {
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        write!(
328            f,
329            "{}",
330            match self {
331                Self::As(alias) => {
332                    format!("AS {}", alias.0)
333                }
334                Self::Elided(alias) => alias.0.clone(),
335            }
336        )
337    }
338}
339
340impl Display for ast::Indexed {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(
343            f,
344            "{}",
345            match self {
346                Self::NotIndexed => "NOT INDEXED".to_string(),
347                Self::IndexedBy(name) => format!("INDEXED BY {}", name.0),
348            }
349        )
350    }
351}
352
353impl Display for ast::JoinOperator {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        write!(
356            f,
357            "{}",
358            match self {
359                Self::Comma => ",".to_string(),
360                Self::TypedJoin(join) => {
361                    let join_keyword = "JOIN";
362                    if let Some(join) = join {
363                        format!("{} {}", join, join_keyword)
364                    } else {
365                        join_keyword.to_string()
366                    }
367                }
368            }
369        )
370    }
371}
372
373impl Display for ast::JoinType {
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        let value = {
376            let mut modifiers = Vec::new();
377            if self.contains(Self::NATURAL) {
378                modifiers.push("NATURAL");
379            }
380            if self.contains(Self::LEFT) || self.contains(Self::RIGHT) {
381                // TODO: I think the parser incorrectly asigns outer to every LEFT and RIGHT query
382                if self.contains(Self::LEFT | Self::RIGHT) {
383                    modifiers.push("FULL");
384                } else if self.contains(Self::LEFT) {
385                    modifiers.push("LEFT");
386                } else if self.contains(Self::RIGHT) {
387                    modifiers.push("RIGHT");
388                }
389                // FIXME: ignore outer joins as I think they are parsed incorrectly in the bitflags
390                // if self.contains(Self::OUTER) {
391                //     modifiers.push("OUTER");
392                // }
393            }
394
395            if self.contains(Self::INNER) {
396                modifiers.push("INNER");
397            }
398            if self.contains(Self::CROSS) {
399                modifiers.push("CROSS");
400            }
401            modifiers.join(" ")
402        };
403        write!(f, "{}", value)
404    }
405}
406
407impl ToSqlString for ast::JoinConstraint {
408    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
409        match self {
410            Self::On(expr) => {
411                format!("ON {}", expr.to_sql_string(context))
412            }
413            Self::Using(col_names) => {
414                let joined_names = col_names
415                    .iter()
416                    .map(|col| col.0.clone())
417                    .collect::<Vec<_>>()
418                    .join(",");
419                format!("USING ({})", joined_names)
420            }
421        }
422    }
423}
424
425impl ToSqlString for ast::GroupBy {
426    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
427        let mut ret = String::from("GROUP BY ");
428        let curr = self
429            .exprs
430            .iter()
431            .map(|expr| expr.to_sql_string(context))
432            .collect::<Vec<_>>()
433            .join(",");
434        ret.push_str(&curr);
435        if let Some(having) = &self.having {
436            ret.push_str(&format!(" HAVING {}", having.to_sql_string(context)));
437        }
438        ret
439    }
440}
441
442impl ToSqlString for ast::WindowDef {
443    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
444        format!("{} AS {}", self.name.0, self.window.to_sql_string(context))
445    }
446}
447
448impl ToSqlString for ast::Window {
449    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
450        let mut ret = Vec::new();
451        if let Some(name) = &self.base {
452            ret.push(name.0.clone());
453        }
454        if let Some(partition) = &self.partition_by {
455            let joined_exprs = partition
456                .iter()
457                .map(|e| e.to_sql_string(context))
458                .collect::<Vec<_>>()
459                .join(",");
460            ret.push(format!("PARTITION BY {}", joined_exprs));
461        }
462        if let Some(order_by) = &self.order_by {
463            let joined_cols = order_by
464                .iter()
465                .map(|col| col.to_sql_string(context))
466                .collect::<Vec<_>>()
467                .join(", ");
468            ret.push(format!("ORDER BY {}", joined_cols));
469        }
470        if let Some(frame_claue) = &self.frame_clause {
471            ret.push(frame_claue.to_sql_string(context));
472        }
473        format!("({})", ret.join(" "))
474    }
475}
476
477impl ToSqlString for ast::FrameClause {
478    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
479        let mut ret = Vec::new();
480        ret.push(self.mode.to_string());
481        let start_sql = self.start.to_sql_string(context);
482        if let Some(end) = &self.end {
483            ret.push(format!(
484                "BETWEEN {} AND {}",
485                start_sql,
486                end.to_sql_string(context)
487            ));
488        } else {
489            ret.push(start_sql);
490        }
491        if let Some(exclude) = &self.exclude {
492            ret.push(exclude.to_string());
493        }
494
495        ret.join(" ")
496    }
497}
498
499impl Display for ast::FrameMode {
500    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501        self.to_fmt(f)
502    }
503}
504
505impl ToSqlString for ast::FrameBound {
506    fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
507        match self {
508            Self::CurrentRow => "CURRENT ROW".to_string(),
509            Self::Following(expr) => format!("{} FOLLOWING", expr.to_sql_string(context)),
510            Self::Preceding(expr) => format!("{} PRECEDING", expr.to_sql_string(context)),
511            Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
512            Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
513        }
514    }
515}
516
517impl Display for ast::FrameExclude {
518    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519        write!(f, "{}", {
520            let clause = match self {
521                Self::CurrentRow => "CURRENT ROW",
522                Self::Group => "GROUP",
523                Self::NoOthers => "NO OTHERS",
524                Self::Ties => "TIES",
525            };
526            format!("EXCLUDE {}", clause)
527        })
528    }
529}
530
#[cfg(test)]
mod tests {
    use crate::to_sql_string_test;

    to_sql_string_test!(test_select_basic, "SELECT 1;");

    to_sql_string_test!(test_select_table, "SELECT * FROM t;");

    to_sql_string_test!(test_select_table_2, "SELECT a FROM t;");

    to_sql_string_test!(test_select_multiple_columns, "SELECT a, b, c FROM t;");

    to_sql_string_test!(test_select_with_alias, "SELECT a AS col1 FROM t;");

    to_sql_string_test!(test_select_with_table_alias, "SELECT t1.a FROM t AS t1;");

    to_sql_string_test!(test_select_with_where, "SELECT a FROM t WHERE b = 1;");

    to_sql_string_test!(
        test_select_with_multiple_conditions,
        "SELECT a FROM t WHERE b = 1 AND c > 2;"
    );

    to_sql_string_test!(
        test_select_with_order_by,
        "SELECT a FROM t ORDER BY a DESC;"
    );

    to_sql_string_test!(test_select_with_limit, "SELECT a FROM t LIMIT 10;");

    to_sql_string_test!(
        test_select_with_offset,
        "SELECT a FROM t LIMIT 10 OFFSET 5;"
    );

    to_sql_string_test!(
        test_select_with_join,
        "SELECT a FROM t JOIN t2 ON t.b = t2.b;"
    );

    to_sql_string_test!(
        test_select_with_group_by,
        "SELECT a, COUNT(*) FROM t GROUP BY a;"
    );

    to_sql_string_test!(
        test_select_with_having,
        "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1;"
    );

    to_sql_string_test!(test_select_with_distinct, "SELECT DISTINCT a FROM t;");

    to_sql_string_test!(test_select_with_function, "SELECT COUNT(a) FROM t;");

    to_sql_string_test!(
        test_select_with_subquery,
        "SELECT a FROM (SELECT b FROM t) AS sub;"
    );

    to_sql_string_test!(
        test_select_nested_subquery,
        "SELECT a FROM (SELECT b FROM (SELECT c FROM t WHERE c > 10) AS sub1 WHERE b < 20) AS sub2;"
    );

    to_sql_string_test!(
        test_select_multiple_joins,
        "SELECT t1.a, t2.b, t3.c FROM t1 JOIN t2 ON t1.id = t2.id LEFT JOIN t3 ON t2.id = t3.id;"
    );

    to_sql_string_test!(
        test_select_with_cte,
        "WITH cte AS (SELECT a FROM t WHERE b = 1) SELECT a FROM cte WHERE a > 10;"
    );

    to_sql_string_test!(
        test_select_with_window_function,
        "SELECT a, ROW_NUMBER() OVER (PARTITION BY b ORDER BY c DESC) AS rn FROM t;"
    );

    to_sql_string_test!(
        test_select_with_complex_where,
        "SELECT a FROM t WHERE b IN (1, 2, 3) AND c BETWEEN 10 AND 20 OR d IS NULL;"
    );

    to_sql_string_test!(
        test_select_with_case,
        "SELECT CASE WHEN a > 0 THEN 'positive' ELSE 'non-positive' END AS result FROM t;"
    );

    to_sql_string_test!(test_select_with_aggregate_and_join, "SELECT t1.a, COUNT(t2.b) FROM t1 LEFT JOIN t2 ON t1.id = t2.id GROUP BY t1.a HAVING COUNT(t2.b) > 5;");

    to_sql_string_test!(test_select_with_multiple_ctes, "WITH cte1 AS (SELECT a FROM t WHERE b = 1), cte2 AS (SELECT c FROM t2 WHERE d = 2) SELECT cte1.a, cte2.c FROM cte1 JOIN cte2 ON cte1.a = cte2.c;");

    to_sql_string_test!(
        test_select_with_union,
        "SELECT a FROM t1 UNION SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_union_all,
        "SELECT a FROM t1 UNION ALL SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_exists,
        "SELECT a FROM t WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b = t.a);"
    );

    to_sql_string_test!(
        test_select_with_correlated_subquery,
        "SELECT a, (SELECT COUNT(*) FROM t2 WHERE t2.b = t.a) AS count_b FROM t;"
    );

    to_sql_string_test!(
        test_select_with_complex_order_by,
        "SELECT a, b FROM t ORDER BY CASE WHEN a IS NULL THEN 1 ELSE 0 END, b ASC, c DESC;"
    );

    to_sql_string_test!(
        test_select_with_full_outer_join,
        "SELECT t1.a, t2.b FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id;",
        ignore = "OUTER JOIN is incorrectly parsed in parser"
    );

    to_sql_string_test!(test_select_with_aggregate_window, "SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS running_sum FROM t;");

    to_sql_string_test!(
        test_select_with_exclude,
        "SELECT 
    c.name,
    o.order_id,
    o.order_amount,
    SUM(o.order_amount) OVER (PARTITION BY c.id
        ORDER BY o.order_date
        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
        EXCLUDE CURRENT ROW) AS running_total_excluding_current
FROM customers c
JOIN orders o ON c.id = o.customer_id
WHERE EXISTS (SELECT 1
    FROM orders o2
    WHERE o2.customer_id = c.id
    AND o2.order_amount > 1000);"
    );

    // Additional coverage for rendering paths in this module that the tests
    // above do not reach: VALUES, other compound operators, NATURAL/USING
    // joins, index hints, CTE column lists / materialization hints, and
    // NULLS ordering.

    to_sql_string_test!(test_select_values, "VALUES (1, 2), (3, 4);");

    to_sql_string_test!(
        test_select_with_intersect,
        "SELECT a FROM t1 INTERSECT SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_except,
        "SELECT a FROM t1 EXCEPT SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_natural_join,
        "SELECT a FROM t1 NATURAL JOIN t2;"
    );

    to_sql_string_test!(
        test_select_with_join_using,
        "SELECT a FROM t1 JOIN t2 USING (id);"
    );

    to_sql_string_test!(
        test_select_with_indexed_by,
        "SELECT a FROM t INDEXED BY idx_a;"
    );

    to_sql_string_test!(test_select_with_not_indexed, "SELECT a FROM t NOT INDEXED;");

    to_sql_string_test!(
        test_select_with_cte_columns,
        "WITH cte (x) AS (SELECT a FROM t) SELECT x FROM cte;"
    );

    to_sql_string_test!(
        test_select_with_materialized_cte,
        "WITH cte AS MATERIALIZED (SELECT a FROM t) SELECT a FROM cte;"
    );

    to_sql_string_test!(
        test_select_with_not_materialized_cte,
        "WITH cte AS NOT MATERIALIZED (SELECT a FROM t) SELECT a FROM cte;"
    );

    to_sql_string_test!(
        test_select_with_nulls_last,
        "SELECT a FROM t ORDER BY a ASC NULLS LAST;"
    );
}