Skip to main content

squawk_fmt/
fmt.rs

1use itertools::Itertools;
2use rowan::Direction;
3use squawk_syntax::ast::{self, AstNode};
4use squawk_syntax::{SyntaxKind, SyntaxNode, SyntaxToken};
5use tiny_pretty::Doc;
6use tiny_pretty::{PrintOptions, print};
7
// TODO: every `syntax().to_string()` call marks a node that still has to be
// converted into the IR properly; `to_string()` is only a temporary hack.
10
11fn build_source_file(source_file: &ast::SourceFile) -> Doc<'_> {
12    let mut doc = Doc::nil();
13    for el in source_file.syntax().children_with_tokens() {
14        match el {
15            rowan::NodeOrToken::Node(node) => {
16                if let Some(stmt) = ast::Stmt::cast(node) {
17                    match stmt {
18                        ast::Stmt::Select(select) => {
19                            doc = doc.append(build_select_doc(&select));
20                        }
21                        ast::Stmt::CreateTable(create_table) => {
22                            doc = doc.append(build_create_table(&create_table));
23                        }
24                        _ => (),
25                    }
26                }
27            }
28            rowan::NodeOrToken::Token(token) => {
29                if token.kind() == SyntaxKind::COMMENT {
30                    doc = doc.append(Doc::text(token.text().to_string()));
31                } else if token.kind() == SyntaxKind::WHITESPACE {
32                    // TODO: I think we can improve this
33                    let lines = token.text().lines().count();
34                    if lines >= 2 {
35                        doc = doc.append(Doc::empty_line()).append(Doc::empty_line());
36                    } else {
37                        doc = doc.append(Doc::empty_line());
38                    }
39                }
40            }
41        }
42    }
43    doc
44}
45
46fn build_create_table<'a>(create_table: &ast::CreateTable) -> Doc<'a> {
47    let mut doc = Doc::text("create")
48        .append(Doc::space())
49        .append(Doc::text("table"))
50        .append(Doc::space())
51        .append(Doc::text(
52            create_table.path().map(|x| x.syntax().to_string()).unwrap(),
53        ))
54        .append(Doc::text("("))
55        .append(
56            Doc::line_or_nil()
57                .append(Doc::list(
58                    Itertools::intersperse(
59                        create_table
60                            .table_arg_list()
61                            .unwrap()
62                            .args()
63                            .map(build_table_arg),
64                        Doc::text(",").append(Doc::hard_line()),
65                    )
66                    .collect(),
67                ))
68                .nest(2)
69                .append(Doc::line_or_nil())
70                .group(),
71        )
72        .append(Doc::text(")"));
73
74    doc = doc.append(build_semicolon(create_table.semicolon_token()));
75
76    doc
77}
78
79fn build_table_arg<'a>(create_table: ast::TableArg) -> Doc<'a> {
80    match create_table {
81        ast::TableArg::Column(column) => Doc::text(column.name().unwrap().syntax().to_string())
82            .append(Doc::space())
83            .append(Doc::text(column.ty().unwrap().syntax().to_string())),
84        ast::TableArg::LikeClause(_like_clause) => todo!(),
85        ast::TableArg::TableConstraint(_table_constraint) => todo!(),
86    }
87}
88
89fn build_select_doc<'a>(select: &ast::Select) -> Doc<'a> {
90    let mut doc = Doc::text("select").append(Doc::line_or_space());
91
92    if let Some(select_clause) = select.select_clause() {
93        if let Some(distinct_clause) = select_clause.distinct_clause() {
94            doc = doc.append(leading_comments(distinct_clause.syntax()));
95            doc = doc.append(Doc::text("distinct")).append(Doc::space());
96        }
97        if let Some(all_token) = select_clause.all_token() {
98            doc = doc.append(leading_comments_token(&all_token));
99            doc = doc.append(Doc::text("all")).append(Doc::space());
100        }
101        if let Some(target_list) = select_clause.target_list() {
102            doc = doc.append(leading_comments(target_list.syntax()));
103            doc = doc
104                .append(Doc::list(
105                    Itertools::intersperse(
106                        target_list.targets().flat_map(build_target),
107                        Doc::text(",").append(Doc::line_or_space()),
108                    )
109                    .collect(),
110                ))
111                .nest(2);
112        }
113    }
114
115    if let Some(from) = &select.from_clause() {
116        doc = doc.append(
117            Doc::line_or_space()
118                .append(Doc::text("from"))
119                .append(Doc::space())
120                .append(Doc::text(
121                    from.from_items().next().unwrap().syntax().to_string(),
122                )),
123        );
124    }
125
126    if let Some(group) = &select.group_by_clause() {
127        doc = doc.append(
128            Doc::line_or_space()
129                .append(Doc::text("group by"))
130                .append(Doc::space())
131                .append(Doc::text(
132                    group.group_by_list().unwrap().syntax().to_string(),
133                )),
134        );
135    }
136
137    doc = doc.append(build_semicolon(select.semicolon_token()));
138
139    doc.group()
140}
141
142fn build_semicolon<'a>(semi: Option<SyntaxToken>) -> Doc<'a> {
143    let Some(semi) = semi else {
144        return Doc::nil();
145    };
146    let mut doc = Doc::nil();
147    let mut comments: Vec<SyntaxToken> = vec![];
148    for next in semi.siblings_with_tokens(Direction::Prev).skip(1) {
149        match next {
150            rowan::NodeOrToken::Node(_) => break,
151            rowan::NodeOrToken::Token(token) => {
152                if token.kind() == SyntaxKind::COMMENT {
153                    comments.push(token);
154                } else if token.kind() == SyntaxKind::WHITESPACE {
155                    continue;
156                } else {
157                    break;
158                }
159            }
160        }
161    }
162    for comment in comments.iter().rev() {
163        doc = doc.append(Doc::text(comment.text().to_string()));
164    }
165    doc.append(Doc::text(";"))
166}
167
168fn build_expr<'a>(expr: ast::Expr) -> Doc<'a> {
169    match expr {
170        ast::Expr::ArrayExpr(array_expr) => {
171            let mut doc = Doc::nil();
172
173            // nested parts of array expressions don't require the array token
174            if array_expr.array_token().is_some() {
175                doc = doc.append(Doc::text("array"));
176            };
177
178            if let Some(select) = array_expr.select() {
179                doc = doc
180                    .append(Doc::text("("))
181                    .append(build_select_doc(&select))
182                    .append(Doc::text(")"))
183            } else {
184                doc = doc
185                    .append(Doc::text("["))
186                    .append(Doc::list(
187                        Itertools::intersperse(
188                            array_expr.exprs().map(build_expr),
189                            Doc::text(",").append(Doc::space()),
190                        )
191                        .collect(),
192                    ))
193                    .append(Doc::text("]"));
194            }
195
196            doc
197        }
198        ast::Expr::BetweenExpr(between_expr) => {
199            let mut doc = build_expr(between_expr.target().unwrap());
200            if between_expr.not_token().is_some() {
201                doc = doc.append(Doc::space()).append(Doc::text("not"));
202            }
203            doc = doc.append(Doc::space()).append(Doc::text("between"));
204            if between_expr.symmetric_token().is_some() {
205                doc = doc.append(Doc::space()).append(Doc::text("symmetric"));
206            }
207            doc.append(Doc::space())
208                .append(build_expr(between_expr.start().unwrap()))
209                .append(Doc::space())
210                .append(Doc::text("and"))
211                .append(Doc::space())
212                .append(build_expr(between_expr.end().unwrap()))
213        }
214        ast::Expr::BinExpr(bin_expr) => build_expr(bin_expr.lhs().unwrap())
215            .append(Doc::space())
216            .append(build_op(bin_expr.op().unwrap()))
217            .append(Doc::space())
218            .append(build_expr(bin_expr.rhs().unwrap())),
219        // ast::Expr::CallExpr(call_expr) => todo!(),
220        // ast::Expr::CaseExpr(case_expr) => todo!(),
221        ast::Expr::CastExpr(cast_expr) => {
222            let mut doc = Doc::nil();
223            if cast_expr.colon_colon().is_some() {
224                doc = doc
225                    .append(build_expr(cast_expr.expr().unwrap()))
226                    .append(Doc::text("::"))
227                    .append(build_type(cast_expr.ty().unwrap()))
228            } else if cast_expr.as_token().is_some() {
229                if cast_expr.cast_token().is_some() {
230                    doc = doc.append(Doc::text("cast"))
231                } else if cast_expr.treat_token().is_some() {
232                    doc = doc.append(Doc::text("treat"))
233                }
234                doc = doc
235                    .append(Doc::text("("))
236                    .append(build_expr(cast_expr.expr().unwrap()))
237                    .append(Doc::space())
238                    .append(Doc::text("as"))
239                    .append(Doc::space())
240                    .append(build_type(cast_expr.ty().unwrap()))
241                    .append(Doc::text(")"))
242            } else {
243                doc = doc
244                    .append(build_type(cast_expr.ty().unwrap()))
245                    .append(Doc::space())
246                    .append(build_literal(cast_expr.literal().unwrap()))
247            }
248            doc
249        }
250        // ast::Expr::FieldExpr(field_expr) => todo!(),
251        // ast::Expr::IndexExpr(index_expr) => todo!(),
252        // ast::Expr::Literal(literal) => todo!(),
253        // ast::Expr::NameRef(name_ref) => todo!(),
254        // ast::Expr::ParenExpr(paren_expr) => todo!(),
255        ast::Expr::PostfixExpr(postfix_expr) => {
256            let expr = build_expr(postfix_expr.expr().unwrap());
257            let op = match postfix_expr.op().unwrap() {
258                ast::PostfixOp::AtLocal(_) => Doc::text("at local"),
259                ast::PostfixOp::IsNull(_) => Doc::text("isnull"),
260                ast::PostfixOp::NotNull(_) => Doc::text("notnull"),
261                ast::PostfixOp::IsJson(n) => {
262                    let mut doc = Doc::text("is json");
263                    if let Some(clause) = n.json_keys_unique_clause() {
264                        doc = doc
265                            .append(Doc::space())
266                            .append(build_json_keys_unique_clause(clause));
267                    }
268                    doc
269                }
270                ast::PostfixOp::IsJsonArray(n) => {
271                    let mut doc = Doc::text("is json array");
272                    if let Some(clause) = n.json_keys_unique_clause() {
273                        doc = doc
274                            .append(Doc::space())
275                            .append(build_json_keys_unique_clause(clause));
276                    }
277                    doc
278                }
279                ast::PostfixOp::IsJsonObject(n) => {
280                    let mut doc = Doc::text("is json object");
281                    if let Some(clause) = n.json_keys_unique_clause() {
282                        doc = doc
283                            .append(Doc::space())
284                            .append(build_json_keys_unique_clause(clause));
285                    }
286                    doc
287                }
288                ast::PostfixOp::IsJsonScalar(n) => {
289                    let mut doc = Doc::text("is json scalar");
290                    if let Some(clause) = n.json_keys_unique_clause() {
291                        doc = doc
292                            .append(Doc::space())
293                            .append(build_json_keys_unique_clause(clause));
294                    }
295                    doc
296                }
297                ast::PostfixOp::IsJsonValue(n) => {
298                    let mut doc = Doc::text("is json value");
299                    if let Some(clause) = n.json_keys_unique_clause() {
300                        doc = doc
301                            .append(Doc::space())
302                            .append(build_json_keys_unique_clause(clause));
303                    }
304                    doc
305                }
306                ast::PostfixOp::IsNormalized(n) => {
307                    let mut doc = Doc::text("is");
308                    if let Some(form) = n.unicode_normal_form() {
309                        doc = doc
310                            .append(Doc::space())
311                            .append(build_unicode_normal_form(form));
312                    }
313                    doc.append(Doc::space()).append(Doc::text("normalized"))
314                }
315                ast::PostfixOp::IsNotJson(n) => {
316                    let mut doc = Doc::text("is not json");
317                    if let Some(clause) = n.json_keys_unique_clause() {
318                        doc = doc
319                            .append(Doc::space())
320                            .append(build_json_keys_unique_clause(clause));
321                    }
322                    doc
323                }
324                ast::PostfixOp::IsNotJsonArray(n) => {
325                    let mut doc = Doc::text("is not json array");
326                    if let Some(clause) = n.json_keys_unique_clause() {
327                        doc = doc
328                            .append(Doc::space())
329                            .append(build_json_keys_unique_clause(clause));
330                    }
331                    doc
332                }
333                ast::PostfixOp::IsNotJsonObject(n) => {
334                    let mut doc = Doc::text("is not json object");
335                    if let Some(clause) = n.json_keys_unique_clause() {
336                        doc = doc
337                            .append(Doc::space())
338                            .append(build_json_keys_unique_clause(clause));
339                    }
340                    doc
341                }
342                ast::PostfixOp::IsNotJsonScalar(n) => {
343                    let mut doc = Doc::text("is not json scalar");
344                    if let Some(clause) = n.json_keys_unique_clause() {
345                        doc = doc
346                            .append(Doc::space())
347                            .append(build_json_keys_unique_clause(clause));
348                    }
349                    doc
350                }
351                ast::PostfixOp::IsNotJsonValue(n) => {
352                    let mut doc = Doc::text("is not json value");
353                    if let Some(clause) = n.json_keys_unique_clause() {
354                        doc = doc
355                            .append(Doc::space())
356                            .append(build_json_keys_unique_clause(clause));
357                    }
358                    doc
359                }
360                ast::PostfixOp::IsNotNormalized(n) => {
361                    let mut doc = Doc::text("is not");
362                    if let Some(form) = n.unicode_normal_form() {
363                        doc = doc
364                            .append(Doc::space())
365                            .append(build_unicode_normal_form(form));
366                    }
367                    doc.append(Doc::space()).append(Doc::text("normalized"))
368                }
369            };
370            expr.append(Doc::space()).append(op)
371        }
372        // ast::Expr::PrefixExpr(prefix_expr) => todo!(),
373        // ast::Expr::SliceExpr(slice_expr) => todo!(),
374        // ast::Expr::TupleExpr(tuple_expr) => todo!(),
375        _ => Doc::text(expr.syntax().to_string()),
376    }
377}
378
379fn build_json_keys_unique_clause<'a>(clause: ast::JsonKeysUniqueClause) -> Doc<'a> {
380    let prefix = if clause.with_token().is_some() {
381        "with"
382    } else {
383        "without"
384    };
385    Doc::text(prefix)
386        .append(Doc::space())
387        .append(Doc::text("unique"))
388        .append(Doc::space())
389        .append(Doc::text("keys"))
390}
391
392fn build_unicode_normal_form<'a>(form: ast::UnicodeNormalForm) -> Doc<'a> {
393    if form.nfc_token().is_some() {
394        Doc::text("nfc")
395    } else if form.nfd_token().is_some() {
396        Doc::text("nfd")
397    } else if form.nfkc_token().is_some() {
398        Doc::text("nfkc")
399    } else {
400        Doc::text("nfkd")
401    }
402}
403
404fn build_op<'a>(op: ast::BinOp) -> Doc<'a> {
405    match op {
406        ast::BinOp::And(_) => todo!(),
407        ast::BinOp::AtTimeZone(_) => todo!(),
408        ast::BinOp::Caret(_) => todo!(),
409        ast::BinOp::Collate(_) => todo!(),
410        ast::BinOp::ColonColon(_) => todo!(),
411        ast::BinOp::ColonEq(_) => todo!(),
412        ast::BinOp::CustomOp(custom_op) => Doc::text(custom_op.syntax().to_string()),
413        ast::BinOp::Eq(_) => todo!(),
414        ast::BinOp::FatArrow(_) => todo!(),
415        ast::BinOp::Gteq(_) => todo!(),
416        ast::BinOp::Ilike(_) => todo!(),
417        ast::BinOp::In(_) => todo!(),
418        ast::BinOp::Is(_) => todo!(),
419        ast::BinOp::IsDistinctFrom(_) => todo!(),
420        ast::BinOp::IsNot(_) => todo!(),
421        ast::BinOp::IsNotDistinctFrom(_) => todo!(),
422        ast::BinOp::LAngle(_) => todo!(),
423        ast::BinOp::Like(_) => todo!(),
424        ast::BinOp::Lteq(_) => todo!(),
425        ast::BinOp::Minus(_) => todo!(),
426        ast::BinOp::Neq(_) => todo!(),
427        ast::BinOp::Neqb(_) => todo!(),
428        ast::BinOp::NotIlike(_) => todo!(),
429        ast::BinOp::NotIn(_) => todo!(),
430        ast::BinOp::NotLike(_) => todo!(),
431        ast::BinOp::NotSimilarTo(_) => todo!(),
432        ast::BinOp::OperatorCall(_) => todo!(),
433        ast::BinOp::Or(_) => todo!(),
434        ast::BinOp::Overlaps(_) => todo!(),
435        ast::BinOp::Percent(_) => todo!(),
436        ast::BinOp::Plus(_) => Doc::text("+"),
437        ast::BinOp::RAngle(_) => todo!(),
438        ast::BinOp::SimilarTo(_) => todo!(),
439        ast::BinOp::Slash(_) => todo!(),
440        ast::BinOp::Star(_) => todo!(),
441    }
442}
443
444fn build_literal<'a>(lit: ast::Literal) -> Doc<'a> {
445    Doc::text(lit.syntax().to_string())
446}
447
448fn build_type<'a>(ty: ast::Type) -> Doc<'a> {
449    Doc::text(ty.syntax().to_string())
450}
451
452fn leading_comments_token<'a>(node: &SyntaxToken) -> Doc<'a> {
453    let mut doc = Doc::nil();
454    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
455        match next {
456            rowan::NodeOrToken::Node(_node) => {
457                break;
458            }
459            rowan::NodeOrToken::Token(token) => {
460                if token.kind() == SyntaxKind::COMMENT {
461                    doc = doc
462                        .append(Doc::text(token.text().to_string()))
463                        .append(Doc::space());
464                } else if token.kind() == SyntaxKind::WHITESPACE {
465                    continue;
466                } else {
467                    break;
468                }
469            }
470        }
471    }
472    doc
473}
474
475fn leading_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
476    let mut doc = Doc::nil();
477    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
478        match next {
479            rowan::NodeOrToken::Node(_node) => {
480                break;
481            }
482            rowan::NodeOrToken::Token(token) => {
483                if token.kind() == SyntaxKind::COMMENT {
484                    let is_block = token.text().starts_with("--");
485                    doc = doc
486                        .append(Doc::text(token.text().to_string()))
487                        .append(if is_block {
488                            Doc::hard_line()
489                        } else {
490                            Doc::space()
491                        });
492                } else if token.kind() == SyntaxKind::WHITESPACE {
493                    continue;
494                } else {
495                    break;
496                }
497            }
498        }
499    }
500    doc
501}
502
503fn trailing_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
504    let mut doc = Doc::nil();
505    for next in node.siblings_with_tokens(Direction::Next).skip(1) {
506        match next {
507            rowan::NodeOrToken::Node(_node) => {
508                break;
509            }
510            rowan::NodeOrToken::Token(token) => {
511                if token.kind() == SyntaxKind::COMMENT {
512                    doc = doc
513                        .append(Doc::space())
514                        .append(Doc::text(token.text().to_string()));
515                } else if token.kind() == SyntaxKind::WHITESPACE {
516                    continue;
517                } else {
518                    break;
519                }
520            }
521        }
522    }
523    doc
524}
525
526fn build_target<'a>(target: ast::Target) -> Option<Doc<'a>> {
527    let mut doc = leading_comments(target.syntax());
528
529    if target.star_token().is_some() {
530        return Some(doc.append(Doc::text("*")));
531    }
532    let expr = target.expr()?;
533    doc = doc.append(build_expr(expr));
534
535    if let Some(as_name) = target.as_name() {
536        if as_name.as_token().is_some() {
537            doc = doc.append(Doc::space()).append(Doc::text("as"))
538        }
539
540        if let Some(name) = as_name.name() {
541            // TODO: quoting or not?
542            doc = doc
543                .append(Doc::space())
544                .append(Doc::text(name.syntax().to_string()));
545        }
546    }
547
548    doc = doc.append(trailing_comments(target.syntax()));
549
550    Some(doc)
551}
552
553pub fn fmt(text: &str) -> String {
554    let parse = ast::SourceFile::parse(text);
555    let file = parse.tree();
556    println!("{text}");
557    println!("---");
558    println!("{:#?}", file.syntax());
559    println!("---");
560    debug_assert_eq!(
561        parse.errors(),
562        vec![],
563        "should bail out when there's parse errors"
564    );
565    let doc = build_source_file(&file);
566    print(&doc, &PrintOptions::default())
567}