1use rustc_hash::FxHashSet;
31
32use rowan::{Direction, NodeOrToken, TextRange};
33use salsa::Database as Db;
34use squawk_syntax::SyntaxKind;
35use squawk_syntax::ast::{self, AstNode, AstToken};
36
37use crate::db::{File, parse};
38
/// Category of a foldable region produced by `folding_ranges`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldKind {
    /// Argument or parameter list (`ARG_LIST`, `TABLE_ARG_LIST`, `PARAM_LIST`).
    ArgList,
    /// Array expression (`ARRAY_EXPR`).
    Array,
    /// Run of two or more adjacent comments that starts with a line comment.
    Comment,
    /// Function call expression (`CALL_EXPR`).
    FunctionCall,
    /// Join clause (`JOIN`).
    Join,
    /// One of the `*_LIST` node kinds enumerated in `fold_kind`.
    List,
    /// Any node kind that `ast::Stmt` can be cast from.
    Statement,
    /// Parenthesised select (`PAREN_SELECT`). `fold_kind` checks `ast::Stmt` first,
    /// and that classification takes precedence.
    Subquery,
    /// Tuple expression (`TUPLE_EXPR`).
    Tuple,
}
51
/// A single foldable region of the source text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Fold {
    /// Text range to fold. Node-based folds start at the node's first child that is
    /// not a comment or whitespace token and end at the end of the node; comment folds
    /// span from the first comment of the run to the end of the last one.
    pub range: TextRange,
    /// The kind of syntax that produced this fold.
    pub kind: FoldKind,
}
57
58#[salsa::tracked]
59pub fn folding_ranges(db: &dyn Db, file: File) -> Vec<Fold> {
60 let parse = parse(db, file);
61
62 let mut folds = vec![];
63 let mut visited_comments = FxHashSet::default();
64
65 for element in parse.tree().syntax().descendants_with_tokens() {
66 match &element {
67 NodeOrToken::Token(token) => {
68 if let Some(comment) = ast::Comment::cast(token.clone())
69 && !visited_comments.contains(&comment)
70 && let Some(range) =
71 contiguous_range_for_comment(comment, &mut visited_comments)
72 {
73 folds.push(Fold {
74 range,
75 kind: FoldKind::Comment,
76 });
77 }
78 }
79 NodeOrToken::Node(node) => {
80 if let Some(kind) = fold_kind(node.kind()) {
81 if !node.text().contains_char('\n') {
82 continue;
83 }
84 let start = node
86 .children_with_tokens()
87 .find(|e| match e {
88 NodeOrToken::Token(t) => {
89 let kind = t.kind();
90 kind != SyntaxKind::COMMENT && kind != SyntaxKind::WHITESPACE
91 }
92 NodeOrToken::Node(_) => true,
93 })
94 .map(|e| e.text_range().start())
95 .unwrap_or_else(|| node.text_range().start());
96 folds.push(Fold {
97 range: TextRange::new(start, node.text_range().end()),
98 kind,
99 });
100 }
101 }
102 }
103 }
104
105 folds
106}
107
108fn fold_kind(kind: SyntaxKind) -> Option<FoldKind> {
109 if ast::Stmt::can_cast(kind) {
110 return Some(FoldKind::Statement);
111 }
112
113 match kind {
114 SyntaxKind::ARG_LIST | SyntaxKind::TABLE_ARG_LIST | SyntaxKind::PARAM_LIST => {
115 Some(FoldKind::ArgList)
116 }
117 SyntaxKind::ARRAY_EXPR => Some(FoldKind::Array),
118 SyntaxKind::CALL_EXPR => Some(FoldKind::FunctionCall),
119 SyntaxKind::JOIN => Some(FoldKind::Join),
120 SyntaxKind::PAREN_SELECT => Some(FoldKind::Subquery),
121 SyntaxKind::TUPLE_EXPR => Some(FoldKind::Tuple),
122 SyntaxKind::WHEN_CLAUSE_LIST
123 | SyntaxKind::ALTER_OPTION_LIST
124 | SyntaxKind::ATTRIBUTE_LIST
125 | SyntaxKind::BEGIN_FUNC_OPTION_LIST
126 | SyntaxKind::CHECKPOINT_OPTION_LIST
127 | SyntaxKind::COLUMN_LIST
128 | SyntaxKind::CONFLICT_INDEX_ITEM_LIST
129 | SyntaxKind::CONSTRAINT_EXCLUSION_LIST
130 | SyntaxKind::COPY_OPTION_LIST
131 | SyntaxKind::DATABASE_OPTION_LIST
132 | SyntaxKind::EXPLAIN_OPTION_LIST
133 | SyntaxKind::DROP_OP_CLASS_OPTION_LIST
134 | SyntaxKind::FDW_OPTION_LIST
135 | SyntaxKind::FUNCTION_SIG_LIST
136 | SyntaxKind::FUNC_OPTION_LIST
137 | SyntaxKind::GRANT_ROLE_OPTION_LIST
138 | SyntaxKind::GROUP_BY_LIST
139 | SyntaxKind::JSON_TABLE_COLUMN_LIST
140 | SyntaxKind::OPERATOR_CLASS_OPTION_LIST
141 | SyntaxKind::OPTION_ITEM_LIST
142 | SyntaxKind::OP_SIG_LIST
143 | SyntaxKind::PARTITION_ITEM_LIST
144 | SyntaxKind::PARTITION_LIST
145 | SyntaxKind::PATH_LIST
146 | SyntaxKind::REINDEX_OPTION_LIST
147 | SyntaxKind::RETURNING_OPTION_LIST
148 | SyntaxKind::REVOKE_COMMAND_LIST
149 | SyntaxKind::ROLE_OPTION_LIST
150 | SyntaxKind::ROLE_REF_LIST
151 | SyntaxKind::ROW_LIST
152 | SyntaxKind::RULE_STMT_LIST
153 | SyntaxKind::SEQUENCE_OPTION_LIST
154 | SyntaxKind::SET_COLUMN_LIST
155 | SyntaxKind::SET_EXPR_LIST
156 | SyntaxKind::SET_OPTIONS_LIST
157 | SyntaxKind::SORT_BY_LIST
158 | SyntaxKind::TABLE_AND_COLUMNS_LIST
159 | SyntaxKind::TABLE_LIST
160 | SyntaxKind::TARGET_LIST
161 | SyntaxKind::TRANSACTION_MODE_LIST
162 | SyntaxKind::TRIGGER_EVENT_LIST
163 | SyntaxKind::VACUUM_OPTION_LIST
164 | SyntaxKind::VARIANT_LIST
165 | SyntaxKind::EXPR_AS_NAME_LIST
166 | SyntaxKind::XML_COLUMN_OPTION_LIST
167 | SyntaxKind::XML_NAMESPACE_LIST
168 | SyntaxKind::XML_TABLE_COLUMN_LIST
169 | SyntaxKind::LABEL_AND_PROPERTIES_LIST
170 | SyntaxKind::PATH_PATTERN_LIST => Some(FoldKind::List),
171 _ => None,
172 }
173}
174
175fn contiguous_range_for_comment(
176 first: ast::Comment,
177 visited: &mut FxHashSet<ast::Comment>,
178) -> Option<TextRange> {
179 visited.insert(first.clone());
180
181 let group_kind = first.kind();
183 if !group_kind.is_line() {
184 return None;
185 }
186
187 let mut last = first.clone();
188 for element in first.syntax().siblings_with_tokens(Direction::Next) {
189 match element {
190 NodeOrToken::Token(token) => {
191 if let Some(ws) = ast::Whitespace::cast(token.clone())
192 && !ws.spans_multiple_lines()
193 {
194 continue;
196 }
197 if let Some(c) = ast::Comment::cast(token) {
198 visited.insert(c.clone());
199 last = c;
200 continue;
201 }
202 break;
206 }
207 NodeOrToken::Node(_) => break,
208 }
209 }
210
211 if first != last {
212 Some(TextRange::new(
213 first.syntax().text_range().start(),
214 last.syntax().text_range().end(),
215 ))
216 } else {
217 None
219 }
220}
221
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;

    use crate::db::{Database, File};

    use super::*;

    /// Short, stable label used for a fold kind in rendered snapshots.
    fn fold_kind_str(kind: &FoldKind) -> &'static str {
        match kind {
            FoldKind::ArgList => "arglist",
            FoldKind::Array => "array",
            FoldKind::Comment => "comment",
            FoldKind::FunctionCall => "function_call",
            FoldKind::Join => "join",
            FoldKind::List => "list",
            FoldKind::Statement => "statement",
            FoldKind::Subquery => "subquery",
            FoldKind::Tuple => "tuple",
        }
    }

    /// Parses `sql`, computes its folds, and renders them inline as `<fold KIND>` /
    /// `</fold>` markers so snapshots show exactly which text each fold covers.
    ///
    /// Returns `sql` unchanged when no folds are produced.
    #[must_use]
    fn check(sql: &str) -> String {
        let db = Database::default();
        let file = File::new(&db, sql.to_string().into());
        let folds = folding_ranges(&db, file);

        if folds.is_empty() {
            return sql.to_string();
        }

        // Field order drives the derived `Ord`: events sort by offset, then closing
        // markers (`is_start == false`) before opening ones. A fold that ends exactly
        // where another begins is therefore closed before the next one opens. Ties
        // between opening markers fall back to the kind label, keeping output stable.
        #[derive(PartialEq, Eq, PartialOrd, Ord)]
        struct Event<'a> {
            offset: usize,
            is_start: bool,
            kind: &'a str,
        }

        let mut events: Vec<Event<'_>> = Vec::with_capacity(folds.len() * 2);
        for fold in &folds {
            let kind = fold_kind_str(&fold.kind);
            events.push(Event {
                offset: fold.range.start().into(),
                is_start: true,
                kind,
            });
            events.push(Event {
                offset: fold.range.end().into(),
                is_start: false,
                kind,
            });
        }
        events.sort();

        let mut output = String::with_capacity(sql.len());
        let mut pos = 0usize;
        for event in &events {
            if event.offset > pos {
                output.push_str(&sql[pos..event.offset]);
                pos = event.offset;
            }
            if event.is_start {
                output.push_str("<fold ");
                output.push_str(event.kind);
                output.push('>');
            } else {
                output.push_str("</fold>");
            }
        }
        output.push_str(&sql[pos..]);
        output
    }

    #[test]
    fn fold_create_table() {
        assert_snapshot!(check("
create table t (
  id int,
  name text
);"), @"
        <fold statement>create table t <fold arglist>(
          id int,
          name text
        )</fold>;</fold>
        ");
    }

    #[test]
    fn fold_select() {
        assert_snapshot!(check("
select
  id,
  name
from t;"), @"
        <fold statement>select
          <fold list>id,
          name</fold>
        from t;</fold>
        ");
    }

    #[test]
    fn do_not_fold_single_line_comment() {
        assert_snapshot!(check("
-- a comment
select 1;"), @"
        -- a comment
        select 1;
        ");
    }

    #[test]
    fn fold_comments_does_not_apply_when_diff_comment_types() {
        assert_snapshot!(check("
/* first part */
-- second part
select 1;"), @"
        /* first part */
        -- second part
        select 1;
        ");
    }

    #[test]
    fn fold_comments_and_multi_statements() {
        assert_snapshot!(check("
-- this is

-- a comment
-- with some more
select a, b, 3
  from t
  where c > 10;"), @"
        -- this is

        <fold comment>-- a comment
        -- with some more</fold>
        <fold statement>select a, b, 3
          from t
          where c > 10;</fold>
        ");
    }

    #[test]
    fn fold_comments_does_not_apply_when_whitespace_between() {
        assert_snapshot!(check("
-- this is

-- a comment
-- with some more
select 1;"), @"
        -- this is

        <fold comment>-- a comment
        -- with some more</fold>
        select 1;
        ");
    }

    #[test]
    fn fold_multiline_comments() {
        assert_snapshot!(check("
-- this is
-- a comment
select 1;"), @"
        <fold comment>-- this is
        -- a comment</fold>
        select 1;
        ");
    }

    #[test]
    fn fold_single_line_no_fold() {
        assert_snapshot!(check("select 1;"), @"select 1;");
    }

    #[test]
    fn fold_subquery() {
        assert_snapshot!(check("
select * from (
  select id from t
);"), @"
        <fold statement>select * from <fold statement>(
          select id from t
        )</fold>;</fold>
        ");
    }

    #[test]
    fn fold_case_when() {
        assert_snapshot!(check("
select
  case
    when x = 1 then 'a'
    when x = 2 then 'b'
  end
from t;"), @"
        <fold statement>select
          <fold list>case
            <fold list>when x = 1 then 'a'
            when x = 2 then 'b'</fold>
          end</fold>
        from t;</fold>
        ");
    }

    #[test]
    fn fold_join() {
        assert_snapshot!(check("
select *
from a
join b
  on a.id = b.id;"), @"
        <fold statement>select *
        from a
        <fold join>join b
          on a.id = b.id</fold>;</fold>
        ");
    }

    #[test]
    fn fold_array_literal() {
        assert_snapshot!(check("
select * from t where
  x = any(array[
    1,
    2,
    3
  ]);"), @"
        <fold statement>select * from t where
          x = <fold function_call>any(<fold array>array[
            1,
            2,
            3
          ]</fold>)</fold>;</fold>
        ");
    }

    #[test]
    fn fold_tuple_literal() {
        assert_snapshot!(check("
select (
  1,
  2,
  3
);"), @"
        <fold statement>select <fold list><fold tuple>(
          1,
          2,
          3
        )</fold></fold>;</fold>
        ");
    }

    #[test]
    fn fold_tuple_bin_expr() {
        assert_snapshot!(check("
select * from x
  where z in (
    1,
    2,
    3,
    4,
    5
  );
"), @"
        <fold statement>select * from x
          where z in <fold tuple>(
            1,
            2,
            3,
            4,
            5
          )</fold>;</fold>
        ");
    }

    #[test]
    fn fold_function_call() {
        assert_snapshot!(check("
select coalesce(
  a,
  b,
  c
);"), @"
        <fold statement>select <fold function_call><fold list>coalesce<fold arglist>(
          a,
          b,
          c
        )</fold></fold></fold>;</fold>
        ");
    }

    #[test]
    fn fold_create_enum() {
        assert_snapshot!(check("
create type status as enum (
  'active',
  'inactive'
);"), @"
        <fold statement>create type status as enum <fold list>(
          'active',
          'inactive'
        )</fold>;</fold>
        ");
    }

    #[test]
    fn fold_insert_values() {
        assert_snapshot!(check("
insert into t (id, name)
values
  (1, 'a'),
  (2, 'b');"), @"
        <fold statement>insert into t (id, name)
        <fold statement>values
          <fold list>(1, 'a'),
          (2, 'b')</fold></fold>;</fold>
        ");
    }

    #[test]
    fn no_fold_single_line_create_table() {
        assert_snapshot!(check("create table t (id int);"), @"create table t (id int);");
    }

    /// Guards `fold_kind`: every syntax kind whose name ends in `_LIST` must fold.
    #[test]
    fn list_variants() {
        let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
            .map(SyntaxKind::from)
            .filter(|kind| format!("{kind:?}").ends_with("_LIST"))
            .filter(|kind| fold_kind(*kind).is_none())
            .collect();

        assert_eq!(
            unhandled_list_kinds,
            vec![],
            "All _LIST SyntaxKind variants should be handled in fold_kind"
        );
    }
}