Skip to main content

squawk_fmt/
fmt.rs

1use itertools::Itertools;
2use rowan::Direction;
3use squawk_syntax::ast::{self, AstNode, LitKind};
4use squawk_syntax::quote::quote_column_alias;
5use squawk_syntax::{SyntaxKind, SyntaxNode, SyntaxToken};
6use tiny_pretty::Doc;
7use tiny_pretty::{PrintOptions, print};
8
9// TODO: anytime we have `syntax().to_string()`, it means we have to do more to
10// actually convert the data into the IR. to_string() is a temp hack
11
12fn build_source_file(source_file: &ast::SourceFile) -> Doc<'_> {
13    let mut doc = Doc::nil();
14    for el in source_file.syntax().children_with_tokens() {
15        match el {
16            rowan::NodeOrToken::Node(node) => {
17                if let Some(stmt) = ast::Stmt::cast(node) {
18                    match stmt {
19                        ast::Stmt::Select(select) => {
20                            doc = doc.append(build_select_doc(&select));
21                        }
22                        ast::Stmt::CreateTable(create_table) => {
23                            doc = doc.append(build_create_table(&create_table));
24                        }
25                        _ => (),
26                    }
27                }
28            }
29            rowan::NodeOrToken::Token(token) => {
30                if token.kind() == SyntaxKind::COMMENT {
31                    doc = doc.append(Doc::text(token.text().to_string()));
32                } else if token.kind() == SyntaxKind::WHITESPACE {
33                    // TODO: I think we can improve this
34                    let lines = token.text().lines().count();
35                    if lines >= 2 {
36                        doc = doc.append(Doc::empty_line()).append(Doc::empty_line());
37                    } else {
38                        doc = doc.append(Doc::empty_line());
39                    }
40                }
41            }
42        }
43    }
44    doc
45}
46
47fn build_create_table<'a>(create_table: &ast::CreateTable) -> Doc<'a> {
48    let mut doc = Doc::text("create")
49        .append(Doc::space())
50        .append(Doc::text("table"))
51        .append(Doc::space())
52        .append(Doc::text(
53            create_table.path().map(|x| x.syntax().to_string()).unwrap(),
54        ))
55        .append(Doc::text("("))
56        .append(
57            Doc::line_or_nil()
58                .append(Doc::list(
59                    Itertools::intersperse(
60                        create_table
61                            .table_arg_list()
62                            .unwrap()
63                            .args()
64                            .map(build_table_arg),
65                        Doc::text(",").append(Doc::hard_line()),
66                    )
67                    .collect(),
68                ))
69                .nest(2)
70                .append(Doc::line_or_nil())
71                .group(),
72        )
73        .append(Doc::text(")"));
74
75    doc = doc.append(build_semicolon(create_table.semicolon_token()));
76
77    doc
78}
79
80fn build_table_arg<'a>(create_table: ast::TableArg) -> Doc<'a> {
81    match create_table {
82        ast::TableArg::Column(column) => build_name(column.name().unwrap())
83            .append(Doc::space())
84            .append(Doc::text(column.ty().unwrap().syntax().to_string())),
85        ast::TableArg::LikeClause(_like_clause) => todo!(),
86        ast::TableArg::TableConstraint(_table_constraint) => todo!(),
87    }
88}
89
90fn build_select_doc<'a>(select: &ast::Select) -> Doc<'a> {
91    let mut doc = Doc::text("select").append(Doc::line_or_space());
92
93    if let Some(select_clause) = select.select_clause() {
94        if let Some(distinct_clause) = select_clause.distinct_clause() {
95            doc = doc.append(leading_comments(distinct_clause.syntax()));
96            doc = doc.append(Doc::text("distinct")).append(Doc::space());
97        }
98        if let Some(all_token) = select_clause.all_token() {
99            doc = doc.append(leading_comments_token(&all_token));
100            doc = doc.append(Doc::text("all")).append(Doc::space());
101        }
102        if let Some(target_list) = select_clause.target_list() {
103            doc = doc.append(leading_comments(target_list.syntax()));
104            doc = doc
105                .append(Doc::list(
106                    Itertools::intersperse(
107                        target_list.targets().flat_map(build_target),
108                        Doc::text(",").append(Doc::line_or_space()),
109                    )
110                    .collect(),
111                ))
112                .nest(2);
113        }
114    }
115
116    if let Some(from) = &select.from_clause() {
117        doc = doc.append(
118            Doc::line_or_space()
119                .append(Doc::text("from"))
120                .append(Doc::space())
121                .append(Doc::text(
122                    from.from_items().next().unwrap().syntax().to_string(),
123                )),
124        );
125    }
126
127    if let Some(group) = &select.group_by_clause() {
128        doc = doc.append(
129            Doc::line_or_space()
130                .append(Doc::text("group by"))
131                .append(Doc::space())
132                .append(Doc::text(
133                    group.group_by_list().unwrap().syntax().to_string(),
134                )),
135        );
136    }
137
138    doc = doc.append(build_semicolon(select.semicolon_token()));
139
140    doc.group()
141}
142
143fn build_semicolon<'a>(semi: Option<SyntaxToken>) -> Doc<'a> {
144    let Some(semi) = semi else {
145        return Doc::nil();
146    };
147    let mut doc = Doc::nil();
148    let mut comments: Vec<SyntaxToken> = vec![];
149    for next in semi.siblings_with_tokens(Direction::Prev).skip(1) {
150        match next {
151            rowan::NodeOrToken::Node(_) => break,
152            rowan::NodeOrToken::Token(token) => {
153                if token.kind() == SyntaxKind::COMMENT {
154                    comments.push(token);
155                } else if token.kind() == SyntaxKind::WHITESPACE {
156                    continue;
157                } else {
158                    break;
159                }
160            }
161        }
162    }
163    for comment in comments.iter().rev() {
164        doc = doc.append(Doc::text(comment.text().to_string()));
165    }
166    doc.append(Doc::text(";"))
167}
168
169fn build_expr<'a>(expr: ast::Expr) -> Doc<'a> {
170    match expr {
171        ast::Expr::ArrayExpr(array_expr) => {
172            let mut doc = Doc::nil();
173
174            // nested parts of array expressions don't require the array token
175            if array_expr.array_token().is_some() {
176                doc = doc.append(Doc::text("array"));
177            };
178
179            if let Some(select) = array_expr.select() {
180                doc = doc
181                    .append(Doc::text("("))
182                    .append(build_select_doc(&select))
183                    .append(Doc::text(")"))
184            } else {
185                doc = doc
186                    .append(Doc::text("["))
187                    .append(Doc::list(
188                        Itertools::intersperse(
189                            array_expr.exprs().map(build_expr),
190                            Doc::text(",").append(Doc::space()),
191                        )
192                        .collect(),
193                    ))
194                    .append(Doc::text("]"));
195            }
196
197            doc
198        }
199        ast::Expr::BetweenExpr(between_expr) => {
200            let mut doc = build_expr(between_expr.target().unwrap());
201            if between_expr.not_token().is_some() {
202                doc = doc.append(Doc::space()).append(Doc::text("not"));
203            }
204            doc = doc.append(Doc::space()).append(Doc::text("between"));
205            if between_expr.symmetric_token().is_some() {
206                doc = doc.append(Doc::space()).append(Doc::text("symmetric"));
207            }
208            doc.append(Doc::space())
209                .append(build_expr(between_expr.start().unwrap()))
210                .append(Doc::space())
211                .append(Doc::text("and"))
212                .append(Doc::space())
213                .append(build_expr(between_expr.end().unwrap()))
214        }
215        ast::Expr::BinExpr(bin_expr) => build_expr(bin_expr.lhs().unwrap())
216            .append(Doc::space())
217            .append(build_op(bin_expr.op().unwrap()))
218            .append(Doc::space())
219            .append(build_expr(bin_expr.rhs().unwrap())),
220        // ast::Expr::CallExpr(call_expr) => todo!(),
221        // ast::Expr::CaseExpr(case_expr) => todo!(),
222        ast::Expr::CastExpr(cast_expr) => {
223            let mut doc = Doc::nil();
224            if cast_expr.colon_colon().is_some() {
225                doc = doc
226                    .append(build_expr(cast_expr.expr().unwrap()))
227                    .append(Doc::text("::"))
228                    .append(build_type(cast_expr.ty().unwrap()))
229            } else if cast_expr.as_token().is_some() {
230                if cast_expr.cast_token().is_some() {
231                    doc = doc.append(Doc::text("cast"))
232                } else if cast_expr.treat_token().is_some() {
233                    doc = doc.append(Doc::text("treat"))
234                }
235                doc = doc
236                    .append(Doc::text("("))
237                    .append(build_expr(cast_expr.expr().unwrap()))
238                    .append(Doc::space())
239                    .append(Doc::text("as"))
240                    .append(Doc::space())
241                    .append(build_type(cast_expr.ty().unwrap()))
242                    .append(Doc::text(")"))
243            } else {
244                doc = doc
245                    .append(build_type(cast_expr.ty().unwrap()))
246                    .append(Doc::space())
247                    .append(build_literal(cast_expr.literal().unwrap()))
248            }
249            doc
250        }
251        // ast::Expr::FieldExpr(field_expr) => todo!(),
252        // ast::Expr::IndexExpr(index_expr) => todo!(),
253        ast::Expr::Literal(literal) => build_literal(literal),
254        // ast::Expr::NameRef(name_ref) => todo!(),
255        // ast::Expr::ParenExpr(paren_expr) => todo!(),
256        ast::Expr::PostfixExpr(postfix_expr) => {
257            let expr = build_expr(postfix_expr.expr().unwrap());
258            let op = match postfix_expr.op().unwrap() {
259                ast::PostfixOp::AtLocal(_) => Doc::text("at local"),
260                ast::PostfixOp::IsNull(_) => Doc::text("isnull"),
261                ast::PostfixOp::NotNull(_) => Doc::text("notnull"),
262                ast::PostfixOp::IsJson(n) => {
263                    let mut doc = Doc::text("is json");
264                    if let Some(clause) = n.json_keys_unique_clause() {
265                        doc = doc
266                            .append(Doc::space())
267                            .append(build_json_keys_unique_clause(clause));
268                    }
269                    doc
270                }
271                ast::PostfixOp::IsJsonArray(n) => {
272                    let mut doc = Doc::text("is json array");
273                    if let Some(clause) = n.json_keys_unique_clause() {
274                        doc = doc
275                            .append(Doc::space())
276                            .append(build_json_keys_unique_clause(clause));
277                    }
278                    doc
279                }
280                ast::PostfixOp::IsJsonObject(n) => {
281                    let mut doc = Doc::text("is json object");
282                    if let Some(clause) = n.json_keys_unique_clause() {
283                        doc = doc
284                            .append(Doc::space())
285                            .append(build_json_keys_unique_clause(clause));
286                    }
287                    doc
288                }
289                ast::PostfixOp::IsJsonScalar(n) => {
290                    let mut doc = Doc::text("is json scalar");
291                    if let Some(clause) = n.json_keys_unique_clause() {
292                        doc = doc
293                            .append(Doc::space())
294                            .append(build_json_keys_unique_clause(clause));
295                    }
296                    doc
297                }
298                ast::PostfixOp::IsJsonValue(n) => {
299                    let mut doc = Doc::text("is json value");
300                    if let Some(clause) = n.json_keys_unique_clause() {
301                        doc = doc
302                            .append(Doc::space())
303                            .append(build_json_keys_unique_clause(clause));
304                    }
305                    doc
306                }
307                ast::PostfixOp::IsNormalized(n) => {
308                    let mut doc = Doc::text("is");
309                    if let Some(form) = n.unicode_normal_form() {
310                        doc = doc
311                            .append(Doc::space())
312                            .append(build_unicode_normal_form(form));
313                    }
314                    doc.append(Doc::space()).append(Doc::text("normalized"))
315                }
316                ast::PostfixOp::IsNotJson(n) => {
317                    let mut doc = Doc::text("is not json");
318                    if let Some(clause) = n.json_keys_unique_clause() {
319                        doc = doc
320                            .append(Doc::space())
321                            .append(build_json_keys_unique_clause(clause));
322                    }
323                    doc
324                }
325                ast::PostfixOp::IsNotJsonArray(n) => {
326                    let mut doc = Doc::text("is not json array");
327                    if let Some(clause) = n.json_keys_unique_clause() {
328                        doc = doc
329                            .append(Doc::space())
330                            .append(build_json_keys_unique_clause(clause));
331                    }
332                    doc
333                }
334                ast::PostfixOp::IsNotJsonObject(n) => {
335                    let mut doc = Doc::text("is not json object");
336                    if let Some(clause) = n.json_keys_unique_clause() {
337                        doc = doc
338                            .append(Doc::space())
339                            .append(build_json_keys_unique_clause(clause));
340                    }
341                    doc
342                }
343                ast::PostfixOp::IsNotJsonScalar(n) => {
344                    let mut doc = Doc::text("is not json scalar");
345                    if let Some(clause) = n.json_keys_unique_clause() {
346                        doc = doc
347                            .append(Doc::space())
348                            .append(build_json_keys_unique_clause(clause));
349                    }
350                    doc
351                }
352                ast::PostfixOp::IsNotJsonValue(n) => {
353                    let mut doc = Doc::text("is not json value");
354                    if let Some(clause) = n.json_keys_unique_clause() {
355                        doc = doc
356                            .append(Doc::space())
357                            .append(build_json_keys_unique_clause(clause));
358                    }
359                    doc
360                }
361                ast::PostfixOp::IsNotNormalized(n) => {
362                    let mut doc = Doc::text("is not");
363                    if let Some(form) = n.unicode_normal_form() {
364                        doc = doc
365                            .append(Doc::space())
366                            .append(build_unicode_normal_form(form));
367                    }
368                    doc.append(Doc::space()).append(Doc::text("normalized"))
369                }
370            };
371            expr.append(Doc::space()).append(op)
372        }
373        // ast::Expr::PrefixExpr(prefix_expr) => todo!(),
374        // ast::Expr::SliceExpr(slice_expr) => todo!(),
375        // ast::Expr::TupleExpr(tuple_expr) => todo!(),
376        _ => Doc::text(expr.syntax().to_string()),
377    }
378}
379
380fn build_json_keys_unique_clause<'a>(clause: ast::JsonKeysUniqueClause) -> Doc<'a> {
381    let prefix = if clause.with_token().is_some() {
382        "with"
383    } else {
384        "without"
385    };
386    Doc::text(prefix)
387        .append(Doc::space())
388        .append(Doc::text("unique"))
389        .append(Doc::space())
390        .append(Doc::text("keys"))
391}
392
393fn build_unicode_normal_form<'a>(form: ast::UnicodeNormalForm) -> Doc<'a> {
394    if form.nfc_token().is_some() {
395        Doc::text("nfc")
396    } else if form.nfd_token().is_some() {
397        Doc::text("nfd")
398    } else if form.nfkc_token().is_some() {
399        Doc::text("nfkc")
400    } else {
401        Doc::text("nfkd")
402    }
403}
404
405fn build_keyword_node<'a>(node: &SyntaxNode) -> Doc<'a> {
406    let mut docs: Vec<Doc<'a>> = vec![];
407    for el in node.children_with_tokens() {
408        match el {
409            rowan::NodeOrToken::Token(token) => match token.kind() {
410                SyntaxKind::WHITESPACE => continue,
411                SyntaxKind::COMMENT => {
412                    if !docs.is_empty() {
413                        docs.push(Doc::space());
414                    }
415                    docs.push(Doc::text(token.text().to_string()));
416                }
417                _ => {
418                    if !docs.is_empty() {
419                        docs.push(Doc::space());
420                    }
421                    docs.push(Doc::text(token.text().to_ascii_lowercase()));
422                }
423            },
424            rowan::NodeOrToken::Node(_) => (),
425        }
426    }
427    Doc::list(docs)
428}
429
430fn build_op<'a>(op: ast::BinOp) -> Doc<'a> {
431    match op {
432        ast::BinOp::And(_) => Doc::text("and"),
433        ast::BinOp::AtTimeZone(n) => build_keyword_node(n.syntax()),
434        ast::BinOp::Caret(_) => Doc::text("^"),
435        ast::BinOp::Collate(_) => Doc::text("collate"),
436        ast::BinOp::ColonColon(_) => Doc::text("::"),
437        ast::BinOp::ColonEq(_) => Doc::text(":="),
438        ast::BinOp::CustomOp(custom_op) => Doc::text(custom_op.syntax().to_string()),
439        ast::BinOp::Eq(_) => Doc::text("="),
440        ast::BinOp::FatArrow(_) => Doc::text("=>"),
441        ast::BinOp::Gteq(_) => Doc::text(">="),
442        ast::BinOp::Ilike(_) => Doc::text("ilike"),
443        ast::BinOp::In(_) => Doc::text("in"),
444        ast::BinOp::Is(_) => Doc::text("is"),
445        ast::BinOp::IsDistinctFrom(n) => build_keyword_node(n.syntax()),
446        ast::BinOp::IsNot(n) => build_keyword_node(n.syntax()),
447        ast::BinOp::IsNotDistinctFrom(n) => build_keyword_node(n.syntax()),
448        ast::BinOp::LAngle(_) => Doc::text("<"),
449        ast::BinOp::Like(_) => Doc::text("like"),
450        ast::BinOp::Lteq(_) => Doc::text("<="),
451        ast::BinOp::Minus(_) => Doc::text("-"),
452        ast::BinOp::Neq(_) => Doc::text("!="),
453        ast::BinOp::Neqb(_) => Doc::text("<>"),
454        ast::BinOp::NotIlike(n) => build_keyword_node(n.syntax()),
455        ast::BinOp::NotIn(n) => build_keyword_node(n.syntax()),
456        ast::BinOp::NotLike(n) => build_keyword_node(n.syntax()),
457        ast::BinOp::NotSimilarTo(n) => build_keyword_node(n.syntax()),
458        ast::BinOp::OperatorCall(op) => Doc::text(op.syntax().to_string()),
459        ast::BinOp::Or(_) => Doc::text("or"),
460        ast::BinOp::Overlaps(_) => Doc::text("overlaps"),
461        ast::BinOp::Percent(_) => Doc::text("%"),
462        ast::BinOp::Plus(_) => Doc::text("+"),
463        ast::BinOp::RAngle(_) => Doc::text(">"),
464        ast::BinOp::SimilarTo(n) => build_keyword_node(n.syntax()),
465        ast::BinOp::Slash(_) => Doc::text("/"),
466        ast::BinOp::Star(_) => Doc::text("*"),
467    }
468}
469
470fn build_literal<'a>(lit: ast::Literal) -> Doc<'a> {
471    let Some(kind) = lit.kind() else {
472        return Doc::nil();
473    };
474    match kind {
475        LitKind::Default(_) => Doc::text("default"),
476        LitKind::False(_) => Doc::text("false"),
477        LitKind::IntNumber(t) => Doc::text(t.text().to_string()),
478        LitKind::Null(_) => Doc::text("null"),
479        LitKind::NumericNumber(t) => Doc::text(t.text().to_string()),
480        LitKind::PositionalParam(t) => Doc::text(t.text().to_string()),
481        LitKind::True(_) => Doc::text("true"),
482        LitKind::BitString(_)
483        | LitKind::ByteString(_)
484        | LitKind::DollarQuotedString(_)
485        | LitKind::EscString(_)
486        | LitKind::String(_)
487        | LitKind::UnicodeEscString(_) => build_string_literal(&lit),
488    }
489}
490
491fn build_string_literal<'a>(lit: &ast::Literal) -> Doc<'a> {
492    let parts: Vec<Doc<'a>> = lit
493        .syntax()
494        .children_with_tokens()
495        .filter_map(|el| match el {
496            rowan::NodeOrToken::Token(t) if t.kind() != SyntaxKind::WHITESPACE => {
497                Some(Doc::text(format_string_token(&t)))
498            }
499            _ => None,
500        })
501        .collect();
502    Doc::list(Itertools::intersperse(parts.into_iter(), Doc::hard_line()).collect())
503}
504
505fn format_string_token(t: &SyntaxToken) -> String {
506    let text = t.text();
507    if matches!(
508        t.kind(),
509        SyntaxKind::STRING | SyntaxKind::DOLLAR_QUOTED_STRING
510    ) {
511        return text.to_string();
512    }
513    match text.find('\'') {
514        Some(idx) => {
515            let (prefix, rest) = text.split_at(idx);
516            let mut s = String::with_capacity(text.len());
517            s.push_str(&prefix.to_ascii_lowercase());
518            s.push_str(rest);
519            s
520        }
521        None => text.to_string(),
522    }
523}
524
525fn build_name<'a>(name: ast::Name) -> Doc<'a> {
526    Doc::text(quote_column_alias(&name.text()))
527}
528
529fn build_type<'a>(ty: ast::Type) -> Doc<'a> {
530    Doc::text(ty.syntax().to_string())
531}
532
533fn leading_comments_token<'a>(node: &SyntaxToken) -> Doc<'a> {
534    let mut doc = Doc::nil();
535    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
536        match next {
537            rowan::NodeOrToken::Node(_node) => {
538                break;
539            }
540            rowan::NodeOrToken::Token(token) => {
541                if token.kind() == SyntaxKind::COMMENT {
542                    doc = doc
543                        .append(Doc::text(token.text().to_string()))
544                        .append(Doc::space());
545                } else if token.kind() == SyntaxKind::WHITESPACE {
546                    continue;
547                } else {
548                    break;
549                }
550            }
551        }
552    }
553    doc
554}
555
556fn leading_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
557    let mut doc = Doc::nil();
558    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
559        match next {
560            rowan::NodeOrToken::Node(_node) => {
561                break;
562            }
563            rowan::NodeOrToken::Token(token) => {
564                if token.kind() == SyntaxKind::COMMENT {
565                    let is_block = token.text().starts_with("--");
566                    doc = doc
567                        .append(Doc::text(token.text().to_string()))
568                        .append(if is_block {
569                            Doc::hard_line()
570                        } else {
571                            Doc::space()
572                        });
573                } else if token.kind() == SyntaxKind::WHITESPACE {
574                    continue;
575                } else {
576                    break;
577                }
578            }
579        }
580    }
581    doc
582}
583
584fn trailing_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
585    let mut doc = Doc::nil();
586    for next in node.siblings_with_tokens(Direction::Next).skip(1) {
587        match next {
588            rowan::NodeOrToken::Node(_node) => {
589                break;
590            }
591            rowan::NodeOrToken::Token(token) => {
592                if token.kind() == SyntaxKind::COMMENT {
593                    doc = doc
594                        .append(Doc::space())
595                        .append(Doc::text(token.text().to_string()));
596                } else if token.kind() == SyntaxKind::WHITESPACE {
597                    continue;
598                } else {
599                    break;
600                }
601            }
602        }
603    }
604    doc
605}
606
607fn build_target<'a>(target: ast::Target) -> Option<Doc<'a>> {
608    let mut doc = leading_comments(target.syntax());
609
610    if target.star_token().is_some() {
611        return Some(doc.append(Doc::text("*")));
612    }
613    let expr = target.expr()?;
614    doc = doc.append(build_expr(expr));
615
616    if let Some(as_name) = target.as_name() {
617        if as_name.as_token().is_some() {
618            doc = doc.append(Doc::space()).append(Doc::text("as"))
619        }
620
621        if let Some(name) = as_name.name() {
622            doc = doc.append(Doc::space()).append(build_name(name));
623        }
624    }
625
626    doc = doc.append(trailing_comments(target.syntax()));
627
628    Some(doc)
629}
630
631pub fn fmt(text: &str) -> String {
632    let parse = ast::SourceFile::parse(text);
633    let file = parse.tree();
634    println!("{text}");
635    println!("---");
636    println!("{:#?}", file.syntax());
637    println!("---");
638    debug_assert_eq!(
639        parse.errors(),
640        vec![],
641        "should bail out when there's parse errors"
642    );
643    let doc = build_source_file(&file);
644    print(&doc, &PrintOptions::default())
645}