1use rustc_hash::FxHashSet;
31
32use rowan::{Direction, NodeOrToken, TextRange};
33use salsa::Database as Db;
34use squawk_syntax::SyntaxKind;
35use squawk_syntax::ast::{self, AstNode, AstToken};
36
37use crate::db::{File, parse};
38
/// The syntactic construct a folding range covers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FoldKind {
    /// An argument or parameter list.
    ArgList,
    /// An array expression.
    Array,
    /// A run of consecutive line comments.
    Comment,
    /// A function call expression.
    FunctionCall,
    /// A join clause.
    Join,
    /// A list node (one of the `*_LIST` syntax kinds).
    List,
    /// A statement.
    Statement,
    /// A parenthesized select.
    Subquery,
    /// A tuple expression.
    Tuple,
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Fold {
54 pub range: TextRange,
55 pub kind: FoldKind,
56}
57
58#[salsa::tracked]
59pub fn folding_ranges(db: &dyn Db, file: File) -> Vec<Fold> {
60 let parse = parse(db, file);
61
62 let mut folds = vec![];
63 let mut visited_comments = FxHashSet::default();
64
65 for element in parse.tree().syntax().descendants_with_tokens() {
66 match &element {
67 NodeOrToken::Token(token) => {
68 if let Some(comment) = ast::Comment::cast(token.clone())
69 && !visited_comments.contains(&comment)
70 && let Some(range) =
71 contiguous_range_for_comment(comment, &mut visited_comments)
72 {
73 folds.push(Fold {
74 range,
75 kind: FoldKind::Comment,
76 });
77 }
78 }
79 NodeOrToken::Node(node) => {
80 if let Some(kind) = fold_kind(node.kind()) {
81 if !node.text().contains_char('\n') {
82 continue;
83 }
84 let start = node
86 .children_with_tokens()
87 .find(|e| match e {
88 NodeOrToken::Token(t) => {
89 let kind = t.kind();
90 kind != SyntaxKind::COMMENT && kind != SyntaxKind::WHITESPACE
91 }
92 NodeOrToken::Node(_) => true,
93 })
94 .map(|e| e.text_range().start())
95 .unwrap_or_else(|| node.text_range().start());
96 folds.push(Fold {
97 range: TextRange::new(start, node.text_range().end()),
98 kind,
99 });
100 }
101 }
102 }
103 }
104
105 folds
106}
107
108fn fold_kind(kind: SyntaxKind) -> Option<FoldKind> {
109 if ast::Stmt::can_cast(kind) {
110 return Some(FoldKind::Statement);
111 }
112
113 match kind {
114 SyntaxKind::ARG_LIST | SyntaxKind::TABLE_ARG_LIST | SyntaxKind::PARAM_LIST => {
115 Some(FoldKind::ArgList)
116 }
117 SyntaxKind::ARRAY_EXPR => Some(FoldKind::Array),
118 SyntaxKind::CALL_EXPR => Some(FoldKind::FunctionCall),
119 SyntaxKind::JOIN => Some(FoldKind::Join),
120 SyntaxKind::PAREN_SELECT => Some(FoldKind::Subquery),
121 SyntaxKind::TUPLE_EXPR => Some(FoldKind::Tuple),
122 SyntaxKind::WHEN_CLAUSE_LIST
123 | SyntaxKind::ALTER_OPTION_LIST
124 | SyntaxKind::ATTRIBUTE_LIST
125 | SyntaxKind::BEGIN_FUNC_OPTION_LIST
126 | SyntaxKind::COLUMN_LIST
127 | SyntaxKind::CONFLICT_INDEX_ITEM_LIST
128 | SyntaxKind::CONSTRAINT_EXCLUSION_LIST
129 | SyntaxKind::COPY_OPTION_LIST
130 | SyntaxKind::DATABASE_OPTION_LIST
131 | SyntaxKind::DROP_OP_CLASS_OPTION_LIST
132 | SyntaxKind::FDW_OPTION_LIST
133 | SyntaxKind::FUNCTION_SIG_LIST
134 | SyntaxKind::FUNC_OPTION_LIST
135 | SyntaxKind::GROUP_BY_LIST
136 | SyntaxKind::JSON_TABLE_COLUMN_LIST
137 | SyntaxKind::OPERATOR_CLASS_OPTION_LIST
138 | SyntaxKind::OPTION_ITEM_LIST
139 | SyntaxKind::OP_SIG_LIST
140 | SyntaxKind::PARTITION_ITEM_LIST
141 | SyntaxKind::PARTITION_LIST
142 | SyntaxKind::PATH_LIST
143 | SyntaxKind::RETURNING_OPTION_LIST
144 | SyntaxKind::REVOKE_COMMAND_LIST
145 | SyntaxKind::ROLE_OPTION_LIST
146 | SyntaxKind::ROLE_REF_LIST
147 | SyntaxKind::ROW_LIST
148 | SyntaxKind::RULE_STMT_LIST
149 | SyntaxKind::SEQUENCE_OPTION_LIST
150 | SyntaxKind::SET_COLUMN_LIST
151 | SyntaxKind::SET_EXPR_LIST
152 | SyntaxKind::SET_OPTIONS_LIST
153 | SyntaxKind::SORT_BY_LIST
154 | SyntaxKind::TABLE_AND_COLUMNS_LIST
155 | SyntaxKind::TABLE_LIST
156 | SyntaxKind::TARGET_LIST
157 | SyntaxKind::TRANSACTION_MODE_LIST
158 | SyntaxKind::TRIGGER_EVENT_LIST
159 | SyntaxKind::VACUUM_OPTION_LIST
160 | SyntaxKind::VARIANT_LIST
161 | SyntaxKind::EXPR_AS_NAME_LIST
162 | SyntaxKind::XML_COLUMN_OPTION_LIST
163 | SyntaxKind::XML_NAMESPACE_LIST
164 | SyntaxKind::XML_TABLE_COLUMN_LIST
165 | SyntaxKind::LABEL_AND_PROPERTIES_LIST
166 | SyntaxKind::PATH_PATTERN_LIST => Some(FoldKind::List),
167 _ => None,
168 }
169}
170
171fn contiguous_range_for_comment(
172 first: ast::Comment,
173 visited: &mut FxHashSet<ast::Comment>,
174) -> Option<TextRange> {
175 visited.insert(first.clone());
176
177 let group_kind = first.kind();
179 if !group_kind.is_line() {
180 return None;
181 }
182
183 let mut last = first.clone();
184 for element in first.syntax().siblings_with_tokens(Direction::Next) {
185 match element {
186 NodeOrToken::Token(token) => {
187 if let Some(ws) = ast::Whitespace::cast(token.clone())
188 && !ws.spans_multiple_lines()
189 {
190 continue;
192 }
193 if let Some(c) = ast::Comment::cast(token) {
194 visited.insert(c.clone());
195 last = c;
196 continue;
197 }
198 break;
202 }
203 NodeOrToken::Node(_) => break,
204 }
205 }
206
207 if first != last {
208 Some(TextRange::new(
209 first.syntax().text_range().start(),
210 last.syntax().text_range().end(),
211 ))
212 } else {
213 None
215 }
216}
217
218#[cfg(test)]
219mod tests {
220 use insta::assert_snapshot;
221
222 use crate::db::{Database, File};
223
224 use super::*;
225
226 fn fold_kind_str(kind: &FoldKind) -> &'static str {
227 match kind {
228 FoldKind::ArgList => "arglist",
229 FoldKind::Array => "array",
230 FoldKind::Comment => "comment",
231 FoldKind::FunctionCall => "function_call",
232 FoldKind::Join => "join",
233 FoldKind::List => "list",
234 FoldKind::Statement => "statement",
235 FoldKind::Subquery => "subquery",
236 FoldKind::Tuple => "tuple",
237 }
238 }
239
240 #[must_use]
241 fn check(sql: &str) -> String {
242 let db = Database::default();
243 let file = File::new(&db, sql.to_string().into());
244 let folds = folding_ranges(&db, file);
245
246 if folds.is_empty() {
247 return sql.to_string();
248 }
249
250 #[derive(PartialEq, Eq, PartialOrd, Ord)]
251 struct Event<'a> {
252 offset: usize,
253 is_end: bool,
254 kind: &'a str,
255 }
256
257 let mut events: Vec<Event<'_>> = vec![];
258 for fold in &folds {
259 let start: usize = fold.range.start().into();
260 let end: usize = fold.range.end().into();
261 let kind = fold_kind_str(&fold.kind);
262 events.push(Event {
263 offset: start,
264 is_end: false,
265 kind,
266 });
267 events.push(Event {
268 offset: end,
269 is_end: true,
270 kind,
271 });
272 }
273 events.sort();
274
275 let mut output = String::new();
276 let mut pos = 0usize;
277 for event in &events {
278 if event.offset > pos {
279 output.push_str(&sql[pos..event.offset]);
280 pos = event.offset;
281 }
282 if event.is_end {
283 output.push_str("</fold>");
284 } else {
285 output.push_str(&format!("<fold {}>", event.kind));
286 }
287 }
288 if pos < sql.len() {
289 output.push_str(&sql[pos..]);
290 }
291 output
292 }
293
294 #[test]
295 fn fold_create_table() {
296 assert_snapshot!(check("
297create table t (
298 id int,
299 name text
300);"), @"
301 <fold statement>create table t <fold arglist>(
302 id int,
303 name text
304 )</fold>;</fold>
305 ");
306 }
307
308 #[test]
309 fn fold_select() {
310 assert_snapshot!(check("
311select
312 id,
313 name
314from t;"), @"
315 <fold statement>select
316 <fold list>id,
317 name</fold>
318 from t;</fold>
319 ");
320 }
321
322 #[test]
323 fn do_not_fold_single_line_comment() {
324 assert_snapshot!(check("
325-- a comment
326select 1;"), @"
327 -- a comment
328 select 1;
329 ");
330 }
331
332 #[test]
333 fn fold_comments_does_not_apply_when_diff_comment_types() {
334 assert_snapshot!(check("
335/* first part */
336-- second part
337select 1;"), @"
338 /* first part */
339 -- second part
340 select 1;
341 ");
342 }
343
344 #[test]
345 fn fold_comments_and_multi_statements() {
346 assert_snapshot!(check("
347-- this is
348
349-- a comment
350-- with some more
351select a, b, 3
352 from t
353 where c > 10;"), @"
354 -- this is
355
356 <fold comment>-- a comment
357 -- with some more</fold>
358 <fold statement>select a, b, 3
359 from t
360 where c > 10;</fold>
361 ");
362 }
363
364 #[test]
365 fn fold_comments_does_not_apply_when_whitespace_between() {
366 assert_snapshot!(check("
367-- this is
368
369-- a comment
370-- with some more
371select 1;"), @"
372 -- this is
373
374 <fold comment>-- a comment
375 -- with some more</fold>
376 select 1;
377 ");
378 }
379
380 #[test]
381 fn fold_multiline_comments() {
382 assert_snapshot!(check("
383-- this is
384-- a comment
385select 1;"), @"
386 <fold comment>-- this is
387 -- a comment</fold>
388 select 1;
389 ");
390 }
391
392 #[test]
393 fn fold_single_line_no_fold() {
394 assert_snapshot!(check("select 1;"), @"select 1;");
395 }
396
397 #[test]
398 fn fold_subquery() {
399 assert_snapshot!(check("
400select * from (
401 select id from t
402);"), @"
403 <fold statement>select * from <fold statement>(
404 select id from t
405 )</fold>;</fold>
406 ");
407 }
408
409 #[test]
410 fn fold_case_when() {
411 assert_snapshot!(check("
412select
413 case
414 when x = 1 then 'a'
415 when x = 2 then 'b'
416 end
417from t;"), @"
418 <fold statement>select
419 <fold list>case
420 <fold list>when x = 1 then 'a'
421 when x = 2 then 'b'</fold>
422 end</fold>
423 from t;</fold>
424 ");
425 }
426
427 #[test]
428 fn fold_join() {
429 assert_snapshot!(check("
430select *
431from a
432join b
433 on a.id = b.id;"), @"
434 <fold statement>select *
435 from a
436 <fold join>join b
437 on a.id = b.id</fold>;</fold>
438 ");
439 }
440
441 #[test]
442 fn fold_array_literal() {
443 assert_snapshot!(check("
444select * from t where
445 x = any(array[
446 1,
447 2,
448 3
449 ]);"), @"
450 <fold statement>select * from t where
451 x = <fold function_call>any(<fold array>array[
452 1,
453 2,
454 3
455 ]</fold>)</fold>;</fold>
456 ");
457 }
458
459 #[test]
460 fn fold_tuple_literal() {
461 assert_snapshot!(check("
462select (
463 1,
464 2,
465 3
466);"), @"
467 <fold statement>select <fold list><fold tuple>(
468 1,
469 2,
470 3
471 )</fold></fold>;</fold>
472 ");
473 }
474
475 #[test]
476 fn fold_tuple_bin_expr() {
477 assert_snapshot!(check("
478select * from x
479 where z in (
480 1,
481 2,
482 3,
483 4,
484 5
485 );
486"), @"
487 <fold statement>select * from x
488 where z in <fold tuple>(
489 1,
490 2,
491 3,
492 4,
493 5
494 )</fold>;</fold>
495 ");
496 }
497
498 #[test]
499 fn fold_function_call() {
500 assert_snapshot!(check("
501select coalesce(
502 a,
503 b,
504 c
505);"), @"
506 <fold statement>select <fold function_call><fold list>coalesce<fold arglist>(
507 a,
508 b,
509 c
510 )</fold></fold></fold>;</fold>
511 ");
512 }
513
514 #[test]
515 fn fold_create_enum() {
516 assert_snapshot!(check("
517create type status as enum (
518 'active',
519 'inactive'
520);"), @"
521 <fold statement>create type status as enum <fold list>(
522 'active',
523 'inactive'
524 )</fold>;</fold>
525 ");
526 }
527
528 #[test]
529 fn fold_insert_values() {
530 assert_snapshot!(check("
531insert into t (id, name)
532values
533 (1, 'a'),
534 (2, 'b');"), @"
535 <fold statement>insert into t (id, name)
536 <fold statement>values
537 <fold list>(1, 'a'),
538 (2, 'b')</fold></fold>;</fold>
539 ");
540 }
541
542 #[test]
543 fn no_fold_single_line_create_table() {
544 assert_snapshot!(check("create table t (id int);"), @"create table t (id int);");
545 }
546
547 #[test]
548 fn list_variants() {
549 let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
550 .map(SyntaxKind::from)
551 .filter(|kind| format!("{kind:?}").ends_with("_LIST"))
552 .filter(|kind| fold_kind(*kind).is_none())
553 .collect();
554
555 assert_eq!(
556 unhandled_list_kinds,
557 vec![],
558 "All _LIST SyntaxKind variants should be handled in fold_kind"
559 );
560 }
561}