Skip to main content

squawk_fmt/
fmt.rs

1use itertools::Itertools;
2use rowan::Direction;
3use squawk_syntax::ast::{self, AstNode, LitKind};
4use squawk_syntax::quote::quote_column_alias;
5use squawk_syntax::{SyntaxKind, SyntaxNode, SyntaxToken};
6use tiny_pretty::Doc;
7use tiny_pretty::{PrintOptions, print};
8
// TODO: every `syntax().to_string()` call below is a temporary hack: it copies
// the node's source text verbatim instead of converting it into the IR.
11
12fn build_source_file(source_file: &ast::SourceFile) -> Doc<'_> {
13    let mut doc = Doc::nil();
14    for el in source_file.syntax().children_with_tokens() {
15        match el {
16            rowan::NodeOrToken::Node(node) => {
17                if let Some(stmt) = ast::Stmt::cast(node) {
18                    match stmt {
19                        ast::Stmt::Select(select) => {
20                            doc = doc.append(build_select_doc(&select));
21                        }
22                        ast::Stmt::CreateTable(create_table) => {
23                            doc = doc.append(build_create_table(&create_table));
24                        }
25                        _ => (),
26                    }
27                }
28            }
29            rowan::NodeOrToken::Token(token) => {
30                if token.kind() == SyntaxKind::COMMENT {
31                    doc = doc.append(Doc::text(token.text().to_string()));
32                } else if token.kind() == SyntaxKind::WHITESPACE {
33                    // TODO: I think we can improve this
34                    let lines = token.text().lines().count();
35                    if lines >= 2 {
36                        doc = doc.append(Doc::empty_line()).append(Doc::empty_line());
37                    } else {
38                        doc = doc.append(Doc::empty_line());
39                    }
40                }
41            }
42        }
43    }
44    doc
45}
46
47fn build_create_table<'a>(create_table: &ast::CreateTable) -> Doc<'a> {
48    let mut doc = Doc::text("create")
49        .append(Doc::space())
50        .append(Doc::text("table"))
51        .append(Doc::space())
52        .append(Doc::text(
53            create_table.path().map(|x| x.syntax().to_string()).unwrap(),
54        ))
55        .append(Doc::text("("))
56        .append(
57            Doc::line_or_nil()
58                .append(Doc::list(
59                    Itertools::intersperse(
60                        create_table
61                            .table_arg_list()
62                            .unwrap()
63                            .args()
64                            .map(build_table_arg),
65                        Doc::text(",").append(Doc::hard_line()),
66                    )
67                    .collect(),
68                ))
69                .nest(2)
70                .append(Doc::line_or_nil())
71                .group(),
72        )
73        .append(Doc::text(")"));
74
75    doc = doc.append(build_semicolon(create_table.semicolon_token()));
76
77    doc
78}
79
80fn build_table_arg<'a>(create_table: ast::TableArg) -> Doc<'a> {
81    match create_table {
82        ast::TableArg::Column(column) => build_name(column.name().unwrap())
83            .append(Doc::space())
84            .append(Doc::text(column.ty().unwrap().syntax().to_string())),
85        ast::TableArg::LikeClause(_like_clause) => todo!(),
86        ast::TableArg::TableConstraint(_table_constraint) => todo!(),
87    }
88}
89
90fn build_select_doc<'a>(select: &ast::Select) -> Doc<'a> {
91    let mut doc = Doc::text("select").append(Doc::line_or_space());
92
93    if let Some(select_clause) = select.select_clause() {
94        if let Some(distinct_clause) = select_clause.distinct_clause() {
95            doc = doc.append(leading_comments(distinct_clause.syntax()));
96            doc = doc.append(Doc::text("distinct")).append(Doc::space());
97        }
98        if let Some(all_token) = select_clause.all_token() {
99            doc = doc.append(leading_comments_token(&all_token));
100            doc = doc.append(Doc::text("all")).append(Doc::space());
101        }
102        if let Some(target_list) = select_clause.target_list() {
103            doc = doc.append(leading_comments(target_list.syntax()));
104            doc = doc
105                .append(Doc::list(
106                    Itertools::intersperse(
107                        target_list.targets().flat_map(build_target),
108                        Doc::text(",").append(Doc::line_or_space()),
109                    )
110                    .collect(),
111                ))
112                .nest(2);
113        }
114    }
115
116    if let Some(from) = &select.from_clause() {
117        doc = doc.append(
118            Doc::line_or_space()
119                .append(Doc::text("from"))
120                .append(Doc::space())
121                .append(Doc::text(
122                    from.from_items().next().unwrap().syntax().to_string(),
123                )),
124        );
125    }
126
127    if let Some(group) = &select.group_by_clause() {
128        let mut group_doc = Doc::line_or_space().append(leading_comments(group.syntax()));
129        group_doc = group_doc.append(Doc::text("group")).append(Doc::space());
130        if let Some(by_token) = group.by_token() {
131            group_doc = group_doc.append(leading_comments_token(&by_token));
132        }
133        group_doc = group_doc.append(Doc::text("by")).append(Doc::space());
134        if let Some(list) = group.group_by_list() {
135            group_doc = group_doc.append(build_group_by_list(list));
136        }
137        doc = doc.append(group_doc);
138    }
139
140    doc = doc.append(build_semicolon(select.semicolon_token()));
141
142    doc.group()
143}
144
145fn build_group_by_list<'a>(list: ast::GroupByList) -> Doc<'a> {
146    leading_comments(list.syntax()).append(Doc::text(list.syntax().to_string()))
147}
148
149fn build_semicolon<'a>(semi: Option<SyntaxToken>) -> Doc<'a> {
150    let Some(semi) = semi else {
151        return Doc::nil();
152    };
153    let mut doc = Doc::nil();
154    let mut comments: Vec<SyntaxToken> = vec![];
155    for next in semi.siblings_with_tokens(Direction::Prev).skip(1) {
156        match next {
157            rowan::NodeOrToken::Node(_) => break,
158            rowan::NodeOrToken::Token(token) => {
159                if token.kind() == SyntaxKind::COMMENT {
160                    comments.push(token);
161                } else if token.kind() == SyntaxKind::WHITESPACE {
162                    continue;
163                } else {
164                    break;
165                }
166            }
167        }
168    }
169    for comment in comments.iter().rev() {
170        doc = doc.append(Doc::text(comment.text().to_string()));
171    }
172    doc.append(Doc::text(";"))
173}
174
175fn build_expr<'a>(expr: ast::Expr) -> Doc<'a> {
176    match expr {
177        ast::Expr::ArrayExpr(array_expr) => {
178            let mut doc = Doc::nil();
179
180            // nested parts of array expressions don't require the array token
181            if array_expr.array_token().is_some() {
182                doc = doc.append(Doc::text("array"));
183            };
184
185            if let Some(select) = array_expr.select() {
186                doc = doc
187                    .append(Doc::text("("))
188                    .append(build_select_doc(&select))
189                    .append(Doc::text(")"))
190            } else {
191                doc = doc
192                    .append(Doc::text("["))
193                    .append(Doc::list(
194                        Itertools::intersperse(
195                            array_expr.exprs().map(build_expr),
196                            Doc::text(",").append(Doc::space()),
197                        )
198                        .collect(),
199                    ))
200                    .append(Doc::text("]"));
201            }
202
203            doc
204        }
205        ast::Expr::BetweenExpr(between_expr) => {
206            let mut doc = build_expr(between_expr.target().unwrap());
207            if between_expr.not_token().is_some() {
208                doc = doc.append(Doc::space()).append(Doc::text("not"));
209            }
210            doc = doc.append(Doc::space()).append(Doc::text("between"));
211            if between_expr.asymmetric_token().is_some() {
212                doc = doc.append(Doc::space()).append(Doc::text("asymmetric"));
213            } else if between_expr.symmetric_token().is_some() {
214                doc = doc.append(Doc::space()).append(Doc::text("symmetric"));
215            }
216            doc.append(Doc::space())
217                .append(build_expr(between_expr.start().unwrap()))
218                .append(Doc::space())
219                .append(Doc::text("and"))
220                .append(Doc::space())
221                .append(build_expr(between_expr.end().unwrap()))
222        }
223        ast::Expr::BinExpr(bin_expr) => build_expr(bin_expr.lhs().unwrap())
224            .append(Doc::space())
225            .append(build_op(bin_expr.op().unwrap()))
226            .append(Doc::space())
227            .append(build_expr(bin_expr.rhs().unwrap())),
228        // ast::Expr::CallExpr(call_expr) => todo!(),
229        // ast::Expr::CaseExpr(case_expr) => todo!(),
230        ast::Expr::CastExpr(cast_expr) => {
231            let mut doc = Doc::nil();
232            if cast_expr.colon_colon().is_some() {
233                doc = doc
234                    .append(build_expr(cast_expr.expr().unwrap()))
235                    .append(Doc::text("::"))
236                    .append(build_type(cast_expr.ty().unwrap()))
237            } else if cast_expr.as_token().is_some() {
238                if cast_expr.cast_token().is_some() {
239                    doc = doc.append(Doc::text("cast"))
240                } else if cast_expr.treat_token().is_some() {
241                    doc = doc.append(Doc::text("treat"))
242                }
243                doc = doc
244                    .append(Doc::text("("))
245                    .append(build_expr(cast_expr.expr().unwrap()))
246                    .append(Doc::space())
247                    .append(Doc::text("as"))
248                    .append(Doc::space())
249                    .append(build_type(cast_expr.ty().unwrap()))
250                    .append(Doc::text(")"))
251            } else {
252                doc = doc
253                    .append(build_type(cast_expr.ty().unwrap()))
254                    .append(Doc::space())
255                    .append(build_literal(cast_expr.literal().unwrap()))
256            }
257            doc
258        }
259        // ast::Expr::FieldExpr(field_expr) => todo!(),
260        // ast::Expr::IndexExpr(index_expr) => todo!(),
261        ast::Expr::Literal(literal) => build_literal(literal),
262        // ast::Expr::NameRef(name_ref) => todo!(),
263        // ast::Expr::ParenExpr(paren_expr) => todo!(),
264        ast::Expr::PostfixExpr(postfix_expr) => {
265            let expr = build_expr(postfix_expr.expr().unwrap());
266            let op = match postfix_expr.op().unwrap() {
267                ast::PostfixOp::AtLocal(_) => Doc::text("at local"),
268                ast::PostfixOp::IsNull(_) => Doc::text("isnull"),
269                ast::PostfixOp::NotNull(_) => Doc::text("notnull"),
270                ast::PostfixOp::IsJson(n) => {
271                    let mut doc = Doc::text("is json");
272                    if let Some(clause) = n.json_keys_unique_clause() {
273                        doc = doc
274                            .append(Doc::space())
275                            .append(build_json_keys_unique_clause(clause));
276                    }
277                    doc
278                }
279                ast::PostfixOp::IsJsonArray(n) => {
280                    let mut doc = Doc::text("is json array");
281                    if let Some(clause) = n.json_keys_unique_clause() {
282                        doc = doc
283                            .append(Doc::space())
284                            .append(build_json_keys_unique_clause(clause));
285                    }
286                    doc
287                }
288                ast::PostfixOp::IsJsonObject(n) => {
289                    let mut doc = Doc::text("is json object");
290                    if let Some(clause) = n.json_keys_unique_clause() {
291                        doc = doc
292                            .append(Doc::space())
293                            .append(build_json_keys_unique_clause(clause));
294                    }
295                    doc
296                }
297                ast::PostfixOp::IsJsonScalar(n) => {
298                    let mut doc = Doc::text("is json scalar");
299                    if let Some(clause) = n.json_keys_unique_clause() {
300                        doc = doc
301                            .append(Doc::space())
302                            .append(build_json_keys_unique_clause(clause));
303                    }
304                    doc
305                }
306                ast::PostfixOp::IsJsonValue(n) => {
307                    let mut doc = Doc::text("is json value");
308                    if let Some(clause) = n.json_keys_unique_clause() {
309                        doc = doc
310                            .append(Doc::space())
311                            .append(build_json_keys_unique_clause(clause));
312                    }
313                    doc
314                }
315                ast::PostfixOp::IsNormalized(n) => {
316                    let mut doc = Doc::text("is");
317                    if let Some(form) = n.unicode_normal_form() {
318                        doc = doc
319                            .append(Doc::space())
320                            .append(build_unicode_normal_form(form));
321                    }
322                    doc.append(Doc::space()).append(Doc::text("normalized"))
323                }
324                ast::PostfixOp::IsNotJson(n) => {
325                    let mut doc = Doc::text("is not json");
326                    if let Some(clause) = n.json_keys_unique_clause() {
327                        doc = doc
328                            .append(Doc::space())
329                            .append(build_json_keys_unique_clause(clause));
330                    }
331                    doc
332                }
333                ast::PostfixOp::IsNotJsonArray(n) => {
334                    let mut doc = Doc::text("is not json array");
335                    if let Some(clause) = n.json_keys_unique_clause() {
336                        doc = doc
337                            .append(Doc::space())
338                            .append(build_json_keys_unique_clause(clause));
339                    }
340                    doc
341                }
342                ast::PostfixOp::IsNotJsonObject(n) => {
343                    let mut doc = Doc::text("is not json object");
344                    if let Some(clause) = n.json_keys_unique_clause() {
345                        doc = doc
346                            .append(Doc::space())
347                            .append(build_json_keys_unique_clause(clause));
348                    }
349                    doc
350                }
351                ast::PostfixOp::IsNotJsonScalar(n) => {
352                    let mut doc = Doc::text("is not json scalar");
353                    if let Some(clause) = n.json_keys_unique_clause() {
354                        doc = doc
355                            .append(Doc::space())
356                            .append(build_json_keys_unique_clause(clause));
357                    }
358                    doc
359                }
360                ast::PostfixOp::IsNotJsonValue(n) => {
361                    let mut doc = Doc::text("is not json value");
362                    if let Some(clause) = n.json_keys_unique_clause() {
363                        doc = doc
364                            .append(Doc::space())
365                            .append(build_json_keys_unique_clause(clause));
366                    }
367                    doc
368                }
369                ast::PostfixOp::IsNotNormalized(n) => {
370                    let mut doc = Doc::text("is not");
371                    if let Some(form) = n.unicode_normal_form() {
372                        doc = doc
373                            .append(Doc::space())
374                            .append(build_unicode_normal_form(form));
375                    }
376                    doc.append(Doc::space()).append(Doc::text("normalized"))
377                }
378            };
379            expr.append(Doc::space()).append(op)
380        }
381        // ast::Expr::PrefixExpr(prefix_expr) => todo!(),
382        // ast::Expr::SliceExpr(slice_expr) => todo!(),
383        // ast::Expr::TupleExpr(tuple_expr) => todo!(),
384        _ => Doc::text(expr.syntax().to_string()),
385    }
386}
387
388fn build_json_keys_unique_clause<'a>(clause: ast::JsonKeysUniqueClause) -> Doc<'a> {
389    let prefix = if clause.with_token().is_some() {
390        "with"
391    } else {
392        "without"
393    };
394    Doc::text(prefix)
395        .append(Doc::space())
396        .append(Doc::text("unique"))
397        .append(Doc::space())
398        .append(Doc::text("keys"))
399}
400
401fn build_unicode_normal_form<'a>(form: ast::UnicodeNormalForm) -> Doc<'a> {
402    if form.nfc_token().is_some() {
403        Doc::text("nfc")
404    } else if form.nfd_token().is_some() {
405        Doc::text("nfd")
406    } else if form.nfkc_token().is_some() {
407        Doc::text("nfkc")
408    } else {
409        Doc::text("nfkd")
410    }
411}
412
413fn build_keyword_node<'a>(node: &SyntaxNode) -> Doc<'a> {
414    let mut docs: Vec<Doc<'a>> = vec![];
415    for el in node.children_with_tokens() {
416        match el {
417            rowan::NodeOrToken::Token(token) => match token.kind() {
418                SyntaxKind::WHITESPACE => continue,
419                SyntaxKind::COMMENT => {
420                    if !docs.is_empty() {
421                        docs.push(Doc::space());
422                    }
423                    docs.push(Doc::text(token.text().to_string()));
424                }
425                _ => {
426                    if !docs.is_empty() {
427                        docs.push(Doc::space());
428                    }
429                    docs.push(Doc::text(token.text().to_ascii_lowercase()));
430                }
431            },
432            rowan::NodeOrToken::Node(_) => (),
433        }
434    }
435    Doc::list(docs)
436}
437
438fn build_op<'a>(op: ast::BinOp) -> Doc<'a> {
439    match op {
440        ast::BinOp::And(_) => Doc::text("and"),
441        ast::BinOp::AtTimeZone(n) => build_keyword_node(n.syntax()),
442        ast::BinOp::Caret(_) => Doc::text("^"),
443        ast::BinOp::Collate(_) => Doc::text("collate"),
444        ast::BinOp::ColonColon(_) => Doc::text("::"),
445        ast::BinOp::ColonEq(_) => Doc::text(":="),
446        ast::BinOp::CustomOp(custom_op) => Doc::text(custom_op.syntax().to_string()),
447        ast::BinOp::Eq(_) => Doc::text("="),
448        ast::BinOp::Escape(_) => Doc::text("escape"),
449        ast::BinOp::FatArrow(_) => Doc::text("=>"),
450        ast::BinOp::Gteq(_) => Doc::text(">="),
451        ast::BinOp::Ilike(_) => Doc::text("ilike"),
452        ast::BinOp::In(_) => Doc::text("in"),
453        ast::BinOp::Is(_) => Doc::text("is"),
454        ast::BinOp::IsDistinctFrom(n) => build_keyword_node(n.syntax()),
455        ast::BinOp::IsNot(n) => build_keyword_node(n.syntax()),
456        ast::BinOp::IsNotDistinctFrom(n) => build_keyword_node(n.syntax()),
457        ast::BinOp::LAngle(_) => Doc::text("<"),
458        ast::BinOp::Like(_) => Doc::text("like"),
459        ast::BinOp::Lteq(_) => Doc::text("<="),
460        ast::BinOp::Minus(_) => Doc::text("-"),
461        ast::BinOp::Neq(_) => Doc::text("!="),
462        ast::BinOp::Neqb(_) => Doc::text("<>"),
463        ast::BinOp::NotIlike(n) => build_keyword_node(n.syntax()),
464        ast::BinOp::NotIn(n) => build_keyword_node(n.syntax()),
465        ast::BinOp::NotLike(n) => build_keyword_node(n.syntax()),
466        ast::BinOp::NotSimilarTo(n) => build_keyword_node(n.syntax()),
467        ast::BinOp::OperatorCall(op) => Doc::text(op.syntax().to_string()),
468        ast::BinOp::Or(_) => Doc::text("or"),
469        ast::BinOp::Overlaps(_) => Doc::text("overlaps"),
470        ast::BinOp::Percent(_) => Doc::text("%"),
471        ast::BinOp::Plus(_) => Doc::text("+"),
472        ast::BinOp::RAngle(_) => Doc::text(">"),
473        ast::BinOp::SimilarTo(n) => build_keyword_node(n.syntax()),
474        ast::BinOp::Slash(_) => Doc::text("/"),
475        ast::BinOp::Star(_) => Doc::text("*"),
476    }
477}
478
479fn build_literal<'a>(lit: ast::Literal) -> Doc<'a> {
480    let Some(kind) = lit.kind() else {
481        return Doc::nil();
482    };
483    match kind {
484        LitKind::Default(_) => Doc::text("default"),
485        LitKind::False(_) => Doc::text("false"),
486        LitKind::IntNumber(t) => Doc::text(t.text().to_string()),
487        LitKind::Null(_) => Doc::text("null"),
488        LitKind::NumericNumber(t) => Doc::text(t.text().to_string()),
489        LitKind::PositionalParam(t) => Doc::text(t.text().to_string()),
490        LitKind::True(_) => Doc::text("true"),
491        LitKind::BitString(_)
492        | LitKind::ByteString(_)
493        | LitKind::DollarQuotedString(_)
494        | LitKind::EscString(_)
495        | LitKind::String(_)
496        | LitKind::UnicodeEscString(_) => build_string_literal(&lit),
497    }
498}
499
500fn build_string_literal<'a>(lit: &ast::Literal) -> Doc<'a> {
501    let parts: Vec<Doc<'a>> = lit
502        .syntax()
503        .children_with_tokens()
504        .filter_map(|el| match el {
505            rowan::NodeOrToken::Token(t) if t.kind() != SyntaxKind::WHITESPACE => {
506                Some(Doc::text(format_string_token(&t)))
507            }
508            _ => None,
509        })
510        .collect();
511    Doc::list(Itertools::intersperse(parts.into_iter(), Doc::hard_line()).collect())
512}
513
514fn format_string_token(t: &SyntaxToken) -> String {
515    let text = t.text();
516    if matches!(
517        t.kind(),
518        SyntaxKind::STRING | SyntaxKind::DOLLAR_QUOTED_STRING
519    ) {
520        return text.to_string();
521    }
522    match text.find('\'') {
523        Some(idx) => {
524            let (prefix, rest) = text.split_at(idx);
525            let mut s = String::with_capacity(text.len());
526            s.push_str(&prefix.to_ascii_lowercase());
527            s.push_str(rest);
528            s
529        }
530        None => text.to_string(),
531    }
532}
533
534fn build_name<'a>(name: ast::Name) -> Doc<'a> {
535    Doc::text(quote_column_alias(&name.text()))
536}
537
538fn build_type<'a>(ty: ast::Type) -> Doc<'a> {
539    Doc::text(ty.syntax().to_string())
540}
541
542fn leading_comments_token<'a>(node: &SyntaxToken) -> Doc<'a> {
543    let mut doc = Doc::nil();
544    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
545        match next {
546            rowan::NodeOrToken::Node(_node) => {
547                break;
548            }
549            rowan::NodeOrToken::Token(token) => {
550                if token.kind() == SyntaxKind::COMMENT {
551                    doc = doc
552                        .append(Doc::text(token.text().to_string()))
553                        .append(Doc::space());
554                } else if token.kind() == SyntaxKind::WHITESPACE {
555                    continue;
556                } else {
557                    break;
558                }
559            }
560        }
561    }
562    doc
563}
564
565fn leading_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
566    let mut doc = Doc::nil();
567    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
568        match next {
569            rowan::NodeOrToken::Node(_node) => {
570                break;
571            }
572            rowan::NodeOrToken::Token(token) => {
573                if token.kind() == SyntaxKind::COMMENT {
574                    let is_block = token.text().starts_with("--");
575                    doc = doc
576                        .append(Doc::text(token.text().to_string()))
577                        .append(if is_block {
578                            Doc::hard_line()
579                        } else {
580                            Doc::space()
581                        });
582                } else if token.kind() == SyntaxKind::WHITESPACE {
583                    continue;
584                } else {
585                    break;
586                }
587            }
588        }
589    }
590    doc
591}
592
593fn trailing_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
594    let mut doc = Doc::nil();
595    for next in node.siblings_with_tokens(Direction::Next).skip(1) {
596        match next {
597            rowan::NodeOrToken::Node(_node) => {
598                break;
599            }
600            rowan::NodeOrToken::Token(token) => {
601                if token.kind() == SyntaxKind::COMMENT {
602                    doc = doc
603                        .append(Doc::space())
604                        .append(Doc::text(token.text().to_string()));
605                } else if token.kind() == SyntaxKind::WHITESPACE {
606                    continue;
607                } else {
608                    break;
609                }
610            }
611        }
612    }
613    doc
614}
615
616fn build_target<'a>(target: ast::Target) -> Option<Doc<'a>> {
617    let mut doc = leading_comments(target.syntax());
618
619    if target.star_token().is_some() {
620        return Some(doc.append(Doc::text("*")));
621    }
622    let expr = target.expr()?;
623    doc = doc.append(build_expr(expr));
624
625    if let Some(as_name) = target.as_name() {
626        if as_name.as_token().is_some() {
627            doc = doc.append(Doc::space()).append(Doc::text("as"))
628        }
629
630        if let Some(name) = as_name.name() {
631            doc = doc.append(Doc::space()).append(build_name(name));
632        }
633    }
634
635    doc = doc.append(trailing_comments(target.syntax()));
636
637    Some(doc)
638}
639
640pub fn fmt(text: &str) -> String {
641    let parse = ast::SourceFile::parse(text);
642    let file = parse.tree();
643    println!("{text}");
644    println!("---");
645    println!("{:#?}", file.syntax());
646    println!("---");
647    debug_assert_eq!(
648        parse.errors(),
649        vec![],
650        "should bail out when there's parse errors"
651    );
652    let doc = build_source_file(&file);
653    print(&doc, &PrintOptions::default())
654}