Skip to main content

squawk_fmt/
fmt.rs

1use itertools::Itertools;
2use rowan::Direction;
3use squawk_syntax::ast::{self, AstNode, LitKind};
4use squawk_syntax::quote::quote_column_alias;
5use squawk_syntax::{SyntaxKind, SyntaxNode, SyntaxToken};
6use tiny_pretty::Doc;
7use tiny_pretty::{PrintOptions, print};
8
// TODO: every `syntax().to_string()` call in this file is a temporary hack: it
// copies the node's source text verbatim instead of lowering it into the Doc IR.
11
12fn build_source_file(source_file: &ast::SourceFile) -> Doc<'_> {
13    let mut doc = Doc::nil();
14    for el in source_file.syntax().children_with_tokens() {
15        match el {
16            rowan::NodeOrToken::Node(node) => {
17                if let Some(stmt) = ast::Stmt::cast(node) {
18                    match stmt {
19                        ast::Stmt::Select(select) => {
20                            doc = doc.append(build_select_doc(&select));
21                        }
22                        ast::Stmt::CreateTable(create_table) => {
23                            doc = doc.append(build_create_table(&create_table));
24                        }
25                        _ => (),
26                    }
27                }
28            }
29            rowan::NodeOrToken::Token(token) => {
30                if token.kind() == SyntaxKind::COMMENT {
31                    doc = doc.append(Doc::text(token.text().to_string()));
32                } else if token.kind() == SyntaxKind::WHITESPACE {
33                    // TODO: I think we can improve this
34                    let lines = token.text().lines().count();
35                    if lines >= 2 {
36                        doc = doc.append(Doc::empty_line()).append(Doc::empty_line());
37                    } else {
38                        doc = doc.append(Doc::empty_line());
39                    }
40                }
41            }
42        }
43    }
44    doc
45}
46
47fn build_create_table<'a>(create_table: &ast::CreateTable) -> Doc<'a> {
48    let mut doc = Doc::text("create")
49        .append(Doc::space())
50        .append(Doc::text("table"))
51        .append(Doc::space())
52        .append(Doc::text(
53            create_table.path().map(|x| x.syntax().to_string()).unwrap(),
54        ))
55        .append(Doc::text("("))
56        .append(
57            Doc::line_or_nil()
58                .append(Doc::list(
59                    Itertools::intersperse(
60                        create_table
61                            .table_arg_list()
62                            .unwrap()
63                            .args()
64                            .map(build_table_arg),
65                        Doc::text(",").append(Doc::hard_line()),
66                    )
67                    .collect(),
68                ))
69                .nest(2)
70                .append(Doc::line_or_nil())
71                .group(),
72        )
73        .append(Doc::text(")"));
74
75    doc = doc.append(build_semicolon(create_table.semicolon_token()));
76
77    doc
78}
79
80fn build_table_arg<'a>(create_table: ast::TableArg) -> Doc<'a> {
81    match create_table {
82        ast::TableArg::Column(column) => build_name(column.name().unwrap())
83            .append(Doc::space())
84            .append(Doc::text(column.ty().unwrap().syntax().to_string())),
85        ast::TableArg::LikeClause(_like_clause) => todo!(),
86        ast::TableArg::TableConstraint(_table_constraint) => todo!(),
87    }
88}
89
90fn build_select_doc<'a>(select: &ast::Select) -> Doc<'a> {
91    let mut doc = Doc::text("select").append(Doc::line_or_space());
92
93    if let Some(select_clause) = select.select_clause() {
94        if let Some(distinct_clause) = select_clause.distinct_clause() {
95            doc = doc.append(leading_comments(distinct_clause.syntax()));
96            doc = doc.append(Doc::text("distinct")).append(Doc::space());
97        }
98        if let Some(all_token) = select_clause.all_token() {
99            doc = doc.append(leading_comments_token(&all_token));
100            doc = doc.append(Doc::text("all")).append(Doc::space());
101        }
102        if let Some(target_list) = select_clause.target_list() {
103            doc = doc.append(leading_comments(target_list.syntax()));
104            doc = doc
105                .append(Doc::list(
106                    Itertools::intersperse(
107                        target_list.targets().flat_map(build_target),
108                        Doc::text(",").append(Doc::line_or_space()),
109                    )
110                    .collect(),
111                ))
112                .nest(2);
113        }
114    }
115
116    if let Some(from) = &select.from_clause() {
117        doc = doc.append(
118            Doc::line_or_space()
119                .append(Doc::text("from"))
120                .append(Doc::space())
121                .append(Doc::text(
122                    from.from_items().next().unwrap().syntax().to_string(),
123                )),
124        );
125    }
126
127    if let Some(group) = &select.group_by_clause() {
128        let mut group_doc = Doc::line_or_space().append(leading_comments(group.syntax()));
129        group_doc = group_doc.append(Doc::text("group")).append(Doc::space());
130        if let Some(by_token) = group.by_token() {
131            group_doc = group_doc.append(leading_comments_token(&by_token));
132        }
133        group_doc = group_doc.append(Doc::text("by")).append(Doc::space());
134        if let Some(list) = group.group_by_list() {
135            group_doc = group_doc.append(build_group_by_list(list));
136        }
137        doc = doc.append(group_doc);
138    }
139
140    doc = doc.append(build_semicolon(select.semicolon_token()));
141
142    doc.group()
143}
144
145fn build_group_by_list<'a>(list: ast::GroupByList) -> Doc<'a> {
146    leading_comments(list.syntax()).append(Doc::text(list.syntax().to_string()))
147}
148
149fn build_semicolon<'a>(semi: Option<SyntaxToken>) -> Doc<'a> {
150    let Some(semi) = semi else {
151        return Doc::nil();
152    };
153    let mut doc = Doc::nil();
154    let mut comments: Vec<SyntaxToken> = vec![];
155    for next in semi.siblings_with_tokens(Direction::Prev).skip(1) {
156        match next {
157            rowan::NodeOrToken::Node(_) => break,
158            rowan::NodeOrToken::Token(token) => {
159                if token.kind() == SyntaxKind::COMMENT {
160                    comments.push(token);
161                } else if token.kind() == SyntaxKind::WHITESPACE {
162                    continue;
163                } else {
164                    break;
165                }
166            }
167        }
168    }
169    for comment in comments.iter().rev() {
170        doc = doc.append(Doc::text(comment.text().to_string()));
171    }
172    doc.append(Doc::text(";"))
173}
174
175fn build_expr<'a>(expr: ast::Expr) -> Doc<'a> {
176    match expr {
177        ast::Expr::ArrayExpr(array_expr) => {
178            let mut doc = Doc::nil();
179
180            // nested parts of array expressions don't require the array token
181            if array_expr.array_token().is_some() {
182                doc = doc.append(Doc::text("array"));
183            };
184
185            if let Some(select) = array_expr.select() {
186                doc = doc
187                    .append(Doc::text("("))
188                    .append(build_select_doc(&select))
189                    .append(Doc::text(")"))
190            } else {
191                doc = doc
192                    .append(Doc::text("["))
193                    .append(Doc::list(
194                        Itertools::intersperse(
195                            array_expr.exprs().map(build_expr),
196                            Doc::text(",").append(Doc::space()),
197                        )
198                        .collect(),
199                    ))
200                    .append(Doc::text("]"));
201            }
202
203            doc
204        }
205        ast::Expr::BetweenExpr(between_expr) => {
206            let mut doc = build_expr(between_expr.target().unwrap());
207            if between_expr.not_token().is_some() {
208                doc = doc.append(Doc::space()).append(Doc::text("not"));
209            }
210            doc = doc.append(Doc::space()).append(Doc::text("between"));
211            if between_expr.symmetric_token().is_some() {
212                doc = doc.append(Doc::space()).append(Doc::text("symmetric"));
213            }
214            doc.append(Doc::space())
215                .append(build_expr(between_expr.start().unwrap()))
216                .append(Doc::space())
217                .append(Doc::text("and"))
218                .append(Doc::space())
219                .append(build_expr(between_expr.end().unwrap()))
220        }
221        ast::Expr::BinExpr(bin_expr) => build_expr(bin_expr.lhs().unwrap())
222            .append(Doc::space())
223            .append(build_op(bin_expr.op().unwrap()))
224            .append(Doc::space())
225            .append(build_expr(bin_expr.rhs().unwrap())),
226        // ast::Expr::CallExpr(call_expr) => todo!(),
227        // ast::Expr::CaseExpr(case_expr) => todo!(),
228        ast::Expr::CastExpr(cast_expr) => {
229            let mut doc = Doc::nil();
230            if cast_expr.colon_colon().is_some() {
231                doc = doc
232                    .append(build_expr(cast_expr.expr().unwrap()))
233                    .append(Doc::text("::"))
234                    .append(build_type(cast_expr.ty().unwrap()))
235            } else if cast_expr.as_token().is_some() {
236                if cast_expr.cast_token().is_some() {
237                    doc = doc.append(Doc::text("cast"))
238                } else if cast_expr.treat_token().is_some() {
239                    doc = doc.append(Doc::text("treat"))
240                }
241                doc = doc
242                    .append(Doc::text("("))
243                    .append(build_expr(cast_expr.expr().unwrap()))
244                    .append(Doc::space())
245                    .append(Doc::text("as"))
246                    .append(Doc::space())
247                    .append(build_type(cast_expr.ty().unwrap()))
248                    .append(Doc::text(")"))
249            } else {
250                doc = doc
251                    .append(build_type(cast_expr.ty().unwrap()))
252                    .append(Doc::space())
253                    .append(build_literal(cast_expr.literal().unwrap()))
254            }
255            doc
256        }
257        // ast::Expr::FieldExpr(field_expr) => todo!(),
258        // ast::Expr::IndexExpr(index_expr) => todo!(),
259        ast::Expr::Literal(literal) => build_literal(literal),
260        // ast::Expr::NameRef(name_ref) => todo!(),
261        // ast::Expr::ParenExpr(paren_expr) => todo!(),
262        ast::Expr::PostfixExpr(postfix_expr) => {
263            let expr = build_expr(postfix_expr.expr().unwrap());
264            let op = match postfix_expr.op().unwrap() {
265                ast::PostfixOp::AtLocal(_) => Doc::text("at local"),
266                ast::PostfixOp::IsNull(_) => Doc::text("isnull"),
267                ast::PostfixOp::NotNull(_) => Doc::text("notnull"),
268                ast::PostfixOp::IsJson(n) => {
269                    let mut doc = Doc::text("is json");
270                    if let Some(clause) = n.json_keys_unique_clause() {
271                        doc = doc
272                            .append(Doc::space())
273                            .append(build_json_keys_unique_clause(clause));
274                    }
275                    doc
276                }
277                ast::PostfixOp::IsJsonArray(n) => {
278                    let mut doc = Doc::text("is json array");
279                    if let Some(clause) = n.json_keys_unique_clause() {
280                        doc = doc
281                            .append(Doc::space())
282                            .append(build_json_keys_unique_clause(clause));
283                    }
284                    doc
285                }
286                ast::PostfixOp::IsJsonObject(n) => {
287                    let mut doc = Doc::text("is json object");
288                    if let Some(clause) = n.json_keys_unique_clause() {
289                        doc = doc
290                            .append(Doc::space())
291                            .append(build_json_keys_unique_clause(clause));
292                    }
293                    doc
294                }
295                ast::PostfixOp::IsJsonScalar(n) => {
296                    let mut doc = Doc::text("is json scalar");
297                    if let Some(clause) = n.json_keys_unique_clause() {
298                        doc = doc
299                            .append(Doc::space())
300                            .append(build_json_keys_unique_clause(clause));
301                    }
302                    doc
303                }
304                ast::PostfixOp::IsJsonValue(n) => {
305                    let mut doc = Doc::text("is json value");
306                    if let Some(clause) = n.json_keys_unique_clause() {
307                        doc = doc
308                            .append(Doc::space())
309                            .append(build_json_keys_unique_clause(clause));
310                    }
311                    doc
312                }
313                ast::PostfixOp::IsNormalized(n) => {
314                    let mut doc = Doc::text("is");
315                    if let Some(form) = n.unicode_normal_form() {
316                        doc = doc
317                            .append(Doc::space())
318                            .append(build_unicode_normal_form(form));
319                    }
320                    doc.append(Doc::space()).append(Doc::text("normalized"))
321                }
322                ast::PostfixOp::IsNotJson(n) => {
323                    let mut doc = Doc::text("is not json");
324                    if let Some(clause) = n.json_keys_unique_clause() {
325                        doc = doc
326                            .append(Doc::space())
327                            .append(build_json_keys_unique_clause(clause));
328                    }
329                    doc
330                }
331                ast::PostfixOp::IsNotJsonArray(n) => {
332                    let mut doc = Doc::text("is not json array");
333                    if let Some(clause) = n.json_keys_unique_clause() {
334                        doc = doc
335                            .append(Doc::space())
336                            .append(build_json_keys_unique_clause(clause));
337                    }
338                    doc
339                }
340                ast::PostfixOp::IsNotJsonObject(n) => {
341                    let mut doc = Doc::text("is not json object");
342                    if let Some(clause) = n.json_keys_unique_clause() {
343                        doc = doc
344                            .append(Doc::space())
345                            .append(build_json_keys_unique_clause(clause));
346                    }
347                    doc
348                }
349                ast::PostfixOp::IsNotJsonScalar(n) => {
350                    let mut doc = Doc::text("is not json scalar");
351                    if let Some(clause) = n.json_keys_unique_clause() {
352                        doc = doc
353                            .append(Doc::space())
354                            .append(build_json_keys_unique_clause(clause));
355                    }
356                    doc
357                }
358                ast::PostfixOp::IsNotJsonValue(n) => {
359                    let mut doc = Doc::text("is not json value");
360                    if let Some(clause) = n.json_keys_unique_clause() {
361                        doc = doc
362                            .append(Doc::space())
363                            .append(build_json_keys_unique_clause(clause));
364                    }
365                    doc
366                }
367                ast::PostfixOp::IsNotNormalized(n) => {
368                    let mut doc = Doc::text("is not");
369                    if let Some(form) = n.unicode_normal_form() {
370                        doc = doc
371                            .append(Doc::space())
372                            .append(build_unicode_normal_form(form));
373                    }
374                    doc.append(Doc::space()).append(Doc::text("normalized"))
375                }
376            };
377            expr.append(Doc::space()).append(op)
378        }
379        // ast::Expr::PrefixExpr(prefix_expr) => todo!(),
380        // ast::Expr::SliceExpr(slice_expr) => todo!(),
381        // ast::Expr::TupleExpr(tuple_expr) => todo!(),
382        _ => Doc::text(expr.syntax().to_string()),
383    }
384}
385
386fn build_json_keys_unique_clause<'a>(clause: ast::JsonKeysUniqueClause) -> Doc<'a> {
387    let prefix = if clause.with_token().is_some() {
388        "with"
389    } else {
390        "without"
391    };
392    Doc::text(prefix)
393        .append(Doc::space())
394        .append(Doc::text("unique"))
395        .append(Doc::space())
396        .append(Doc::text("keys"))
397}
398
399fn build_unicode_normal_form<'a>(form: ast::UnicodeNormalForm) -> Doc<'a> {
400    if form.nfc_token().is_some() {
401        Doc::text("nfc")
402    } else if form.nfd_token().is_some() {
403        Doc::text("nfd")
404    } else if form.nfkc_token().is_some() {
405        Doc::text("nfkc")
406    } else {
407        Doc::text("nfkd")
408    }
409}
410
411fn build_keyword_node<'a>(node: &SyntaxNode) -> Doc<'a> {
412    let mut docs: Vec<Doc<'a>> = vec![];
413    for el in node.children_with_tokens() {
414        match el {
415            rowan::NodeOrToken::Token(token) => match token.kind() {
416                SyntaxKind::WHITESPACE => continue,
417                SyntaxKind::COMMENT => {
418                    if !docs.is_empty() {
419                        docs.push(Doc::space());
420                    }
421                    docs.push(Doc::text(token.text().to_string()));
422                }
423                _ => {
424                    if !docs.is_empty() {
425                        docs.push(Doc::space());
426                    }
427                    docs.push(Doc::text(token.text().to_ascii_lowercase()));
428                }
429            },
430            rowan::NodeOrToken::Node(_) => (),
431        }
432    }
433    Doc::list(docs)
434}
435
436fn build_op<'a>(op: ast::BinOp) -> Doc<'a> {
437    match op {
438        ast::BinOp::And(_) => Doc::text("and"),
439        ast::BinOp::AtTimeZone(n) => build_keyword_node(n.syntax()),
440        ast::BinOp::Caret(_) => Doc::text("^"),
441        ast::BinOp::Collate(_) => Doc::text("collate"),
442        ast::BinOp::ColonColon(_) => Doc::text("::"),
443        ast::BinOp::ColonEq(_) => Doc::text(":="),
444        ast::BinOp::CustomOp(custom_op) => Doc::text(custom_op.syntax().to_string()),
445        ast::BinOp::Eq(_) => Doc::text("="),
446        ast::BinOp::FatArrow(_) => Doc::text("=>"),
447        ast::BinOp::Gteq(_) => Doc::text(">="),
448        ast::BinOp::Ilike(_) => Doc::text("ilike"),
449        ast::BinOp::In(_) => Doc::text("in"),
450        ast::BinOp::Is(_) => Doc::text("is"),
451        ast::BinOp::IsDistinctFrom(n) => build_keyword_node(n.syntax()),
452        ast::BinOp::IsNot(n) => build_keyword_node(n.syntax()),
453        ast::BinOp::IsNotDistinctFrom(n) => build_keyword_node(n.syntax()),
454        ast::BinOp::LAngle(_) => Doc::text("<"),
455        ast::BinOp::Like(_) => Doc::text("like"),
456        ast::BinOp::Lteq(_) => Doc::text("<="),
457        ast::BinOp::Minus(_) => Doc::text("-"),
458        ast::BinOp::Neq(_) => Doc::text("!="),
459        ast::BinOp::Neqb(_) => Doc::text("<>"),
460        ast::BinOp::NotIlike(n) => build_keyword_node(n.syntax()),
461        ast::BinOp::NotIn(n) => build_keyword_node(n.syntax()),
462        ast::BinOp::NotLike(n) => build_keyword_node(n.syntax()),
463        ast::BinOp::NotSimilarTo(n) => build_keyword_node(n.syntax()),
464        ast::BinOp::OperatorCall(op) => Doc::text(op.syntax().to_string()),
465        ast::BinOp::Or(_) => Doc::text("or"),
466        ast::BinOp::Overlaps(_) => Doc::text("overlaps"),
467        ast::BinOp::Percent(_) => Doc::text("%"),
468        ast::BinOp::Plus(_) => Doc::text("+"),
469        ast::BinOp::RAngle(_) => Doc::text(">"),
470        ast::BinOp::SimilarTo(n) => build_keyword_node(n.syntax()),
471        ast::BinOp::Slash(_) => Doc::text("/"),
472        ast::BinOp::Star(_) => Doc::text("*"),
473    }
474}
475
476fn build_literal<'a>(lit: ast::Literal) -> Doc<'a> {
477    let Some(kind) = lit.kind() else {
478        return Doc::nil();
479    };
480    match kind {
481        LitKind::Default(_) => Doc::text("default"),
482        LitKind::False(_) => Doc::text("false"),
483        LitKind::IntNumber(t) => Doc::text(t.text().to_string()),
484        LitKind::Null(_) => Doc::text("null"),
485        LitKind::NumericNumber(t) => Doc::text(t.text().to_string()),
486        LitKind::PositionalParam(t) => Doc::text(t.text().to_string()),
487        LitKind::True(_) => Doc::text("true"),
488        LitKind::BitString(_)
489        | LitKind::ByteString(_)
490        | LitKind::DollarQuotedString(_)
491        | LitKind::EscString(_)
492        | LitKind::String(_)
493        | LitKind::UnicodeEscString(_) => build_string_literal(&lit),
494    }
495}
496
497fn build_string_literal<'a>(lit: &ast::Literal) -> Doc<'a> {
498    let parts: Vec<Doc<'a>> = lit
499        .syntax()
500        .children_with_tokens()
501        .filter_map(|el| match el {
502            rowan::NodeOrToken::Token(t) if t.kind() != SyntaxKind::WHITESPACE => {
503                Some(Doc::text(format_string_token(&t)))
504            }
505            _ => None,
506        })
507        .collect();
508    Doc::list(Itertools::intersperse(parts.into_iter(), Doc::hard_line()).collect())
509}
510
511fn format_string_token(t: &SyntaxToken) -> String {
512    let text = t.text();
513    if matches!(
514        t.kind(),
515        SyntaxKind::STRING | SyntaxKind::DOLLAR_QUOTED_STRING
516    ) {
517        return text.to_string();
518    }
519    match text.find('\'') {
520        Some(idx) => {
521            let (prefix, rest) = text.split_at(idx);
522            let mut s = String::with_capacity(text.len());
523            s.push_str(&prefix.to_ascii_lowercase());
524            s.push_str(rest);
525            s
526        }
527        None => text.to_string(),
528    }
529}
530
531fn build_name<'a>(name: ast::Name) -> Doc<'a> {
532    Doc::text(quote_column_alias(&name.text()))
533}
534
535fn build_type<'a>(ty: ast::Type) -> Doc<'a> {
536    Doc::text(ty.syntax().to_string())
537}
538
539fn leading_comments_token<'a>(node: &SyntaxToken) -> Doc<'a> {
540    let mut doc = Doc::nil();
541    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
542        match next {
543            rowan::NodeOrToken::Node(_node) => {
544                break;
545            }
546            rowan::NodeOrToken::Token(token) => {
547                if token.kind() == SyntaxKind::COMMENT {
548                    doc = doc
549                        .append(Doc::text(token.text().to_string()))
550                        .append(Doc::space());
551                } else if token.kind() == SyntaxKind::WHITESPACE {
552                    continue;
553                } else {
554                    break;
555                }
556            }
557        }
558    }
559    doc
560}
561
562fn leading_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
563    let mut doc = Doc::nil();
564    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
565        match next {
566            rowan::NodeOrToken::Node(_node) => {
567                break;
568            }
569            rowan::NodeOrToken::Token(token) => {
570                if token.kind() == SyntaxKind::COMMENT {
571                    let is_block = token.text().starts_with("--");
572                    doc = doc
573                        .append(Doc::text(token.text().to_string()))
574                        .append(if is_block {
575                            Doc::hard_line()
576                        } else {
577                            Doc::space()
578                        });
579                } else if token.kind() == SyntaxKind::WHITESPACE {
580                    continue;
581                } else {
582                    break;
583                }
584            }
585        }
586    }
587    doc
588}
589
590fn trailing_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
591    let mut doc = Doc::nil();
592    for next in node.siblings_with_tokens(Direction::Next).skip(1) {
593        match next {
594            rowan::NodeOrToken::Node(_node) => {
595                break;
596            }
597            rowan::NodeOrToken::Token(token) => {
598                if token.kind() == SyntaxKind::COMMENT {
599                    doc = doc
600                        .append(Doc::space())
601                        .append(Doc::text(token.text().to_string()));
602                } else if token.kind() == SyntaxKind::WHITESPACE {
603                    continue;
604                } else {
605                    break;
606                }
607            }
608        }
609    }
610    doc
611}
612
613fn build_target<'a>(target: ast::Target) -> Option<Doc<'a>> {
614    let mut doc = leading_comments(target.syntax());
615
616    if target.star_token().is_some() {
617        return Some(doc.append(Doc::text("*")));
618    }
619    let expr = target.expr()?;
620    doc = doc.append(build_expr(expr));
621
622    if let Some(as_name) = target.as_name() {
623        if as_name.as_token().is_some() {
624            doc = doc.append(Doc::space()).append(Doc::text("as"))
625        }
626
627        if let Some(name) = as_name.name() {
628            doc = doc.append(Doc::space()).append(build_name(name));
629        }
630    }
631
632    doc = doc.append(trailing_comments(target.syntax()));
633
634    Some(doc)
635}
636
637pub fn fmt(text: &str) -> String {
638    let parse = ast::SourceFile::parse(text);
639    let file = parse.tree();
640    println!("{text}");
641    println!("---");
642    println!("{:#?}", file.syntax());
643    println!("---");
644    debug_assert_eq!(
645        parse.errors(),
646        vec![],
647        "should bail out when there's parse errors"
648    );
649    let doc = build_source_file(&file);
650    print(&doc, &PrintOptions::default())
651}