solidity_language_server/
folding.rs1use tower_lsp::lsp_types::{FoldingRange, FoldingRangeKind};
2use tree_sitter::{Node, Parser};
3
4pub fn folding_ranges(source: &str) -> Vec<FoldingRange> {
9 let tree = match parse(source) {
10 Some(t) => t,
11 None => return vec![],
12 };
13 let mut ranges = Vec::new();
14 collect_folding_ranges(tree.root_node(), source, &mut ranges);
15 collect_comment_folds(tree.root_node(), source, &mut ranges);
16 collect_import_folds(tree.root_node(), &mut ranges);
17 ranges
18}
19
20fn collect_folding_ranges(node: Node, source: &str, out: &mut Vec<FoldingRange>) {
22 match node.kind() {
23 "contract_declaration" | "interface_declaration" | "library_declaration" => {
26 if let Some(body) = find_child(node, "contract_body") {
27 push_brace_fold(body, None, out);
28 walk_children(body, source, out);
29 }
30 return;
31 }
32 "struct_declaration" => {
33 if let Some(body) = find_child(node, "struct_body") {
34 push_brace_fold(body, None, out);
35 }
36 return;
37 }
38 "enum_declaration" => {
39 if let Some(body) = find_child(node, "enum_body") {
40 push_brace_fold(body, None, out);
41 }
42 return;
43 }
44
45 "function_definition"
48 | "constructor_definition"
49 | "modifier_definition"
50 | "fallback_receive_definition" => {
51 if let Some(body) = find_child(node, "function_body") {
52 push_brace_fold(body, None, out);
53 walk_children(body, source, out);
54 }
55 return;
56 }
57
58 "block_statement" | "unchecked_block" => {
60 push_brace_fold(node, None, out);
61 }
62
63 "if_statement" | "for_statement" | "while_statement" | "do_while_statement"
66 | "try_statement" => {}
67
68 "assembly_statement" => {
70 if let Some(body) = find_child(node, "yul_block") {
71 push_brace_fold(body, None, out);
72 }
73 }
74
75 "event_definition" | "error_declaration" => {
77 push_multiline_fold(node, None, out);
78 }
79
80 _ => {}
81 }
82
83 walk_children(node, source, out);
84}
85
86fn walk_children(node: Node, source: &str, out: &mut Vec<FoldingRange>) {
87 let mut cursor = node.walk();
88 for child in node.children(&mut cursor) {
89 if child.is_named() {
90 collect_folding_ranges(child, source, out);
91 }
92 }
93}
94
95fn collect_comment_folds(root: Node, source: &str, out: &mut Vec<FoldingRange>) {
101 let mut cursor = root.walk();
102 let children: Vec<Node> = root
103 .children(&mut cursor)
104 .filter(|c| c.kind() == "comment")
105 .collect();
106
107 let mut i = 0;
108 while i < children.len() {
109 let node = children[i];
110 let text = &source[node.byte_range()];
111 let start_line = node.start_position().row as u32;
112 let end_line = node.end_position().row as u32;
113
114 if text.starts_with("/*") {
115 if end_line > start_line {
117 out.push(FoldingRange {
118 start_line,
119 start_character: Some(node.start_position().column as u32),
120 end_line,
121 end_character: Some(node.end_position().column as u32),
122 kind: Some(FoldingRangeKind::Comment),
123 collapsed_text: None,
124 });
125 }
126 i += 1;
127 } else if text.starts_with("//") {
128 let group_start = start_line;
130 let mut group_end = end_line;
131 let mut j = i + 1;
132 while j < children.len() {
133 let next = children[j];
134 let next_text = &source[next.byte_range()];
135 let next_start = next.start_position().row as u32;
136 if next_text.starts_with("//") && next_start == group_end + 1 {
137 group_end = next.end_position().row as u32;
138 j += 1;
139 } else {
140 break;
141 }
142 }
143 if group_end > group_start {
144 out.push(FoldingRange {
145 start_line: group_start,
146 start_character: Some(node.start_position().column as u32),
147 end_line: group_end,
148 end_character: None,
149 kind: Some(FoldingRangeKind::Comment),
150 collapsed_text: None,
151 });
152 }
153 i = j;
154 } else {
155 i += 1;
156 }
157 }
158
159 let mut cursor2 = root.walk();
161 for child in root.children(&mut cursor2) {
162 if child.is_named()
163 && has_body(child)
164 && let Some(body) = find_body(child)
165 {
166 collect_comment_folds(body, source, out);
167 }
168 }
169}
170
171fn collect_import_folds(root: Node, out: &mut Vec<FoldingRange>) {
173 let mut cursor = root.walk();
174 let children: Vec<Node> = root
175 .children(&mut cursor)
176 .filter(|c| c.is_named())
177 .collect();
178
179 let mut i = 0;
180 while i < children.len() {
181 if children[i].kind() == "import_directive" {
182 let start_line = children[i].start_position().row as u32;
183 let start_char = children[i].start_position().column as u32;
184 let mut end_line = children[i].end_position().row as u32;
185
186 if end_line > start_line {
188 out.push(FoldingRange {
189 start_line,
190 start_character: Some(start_char),
191 end_line,
192 end_character: Some(children[i].end_position().column as u32),
193 kind: Some(FoldingRangeKind::Imports),
194 collapsed_text: None,
195 });
196 }
197
198 let mut j = i + 1;
200 while j < children.len() && children[j].kind() == "import_directive" {
201 end_line = children[j].end_position().row as u32;
202 j += 1;
203 }
204 if j > i + 1 {
205 out.push(FoldingRange {
207 start_line,
208 start_character: Some(start_char),
209 end_line,
210 end_character: None,
211 kind: Some(FoldingRangeKind::Imports),
212 collapsed_text: None,
213 });
214 }
215 i = j;
216 } else {
217 i += 1;
218 }
219 }
220}
221
222fn parse(source: &str) -> Option<tree_sitter::Tree> {
225 let mut parser = Parser::new();
226 parser
227 .set_language(&tree_sitter_solidity::LANGUAGE.into())
228 .expect("failed to load Solidity grammar");
229 parser.parse(source, None)
230}
231
232fn push_brace_fold(node: Node, kind: Option<FoldingRangeKind>, out: &mut Vec<FoldingRange>) {
235 let start_line = node.start_position().row as u32;
236 let end_line = node.end_position().row as u32;
237 if end_line > start_line {
238 out.push(FoldingRange {
239 start_line,
240 start_character: Some(node.start_position().column as u32),
241 end_line,
242 end_character: Some(node.end_position().column as u32),
243 kind,
244 collapsed_text: None,
245 });
246 }
247}
248
249fn push_multiline_fold(node: Node, kind: Option<FoldingRangeKind>, out: &mut Vec<FoldingRange>) {
251 let start_line = node.start_position().row as u32;
252 let end_line = node.end_position().row as u32;
253 if end_line > start_line {
254 out.push(FoldingRange {
255 start_line,
256 start_character: Some(node.start_position().column as u32),
257 end_line,
258 end_character: Some(node.end_position().column as u32),
259 kind,
260 collapsed_text: None,
261 });
262 }
263}
264
265fn find_child<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
266 let mut cursor = node.walk();
267 node.children(&mut cursor).find(|c| c.kind() == kind)
268}
269
270fn has_body(node: Node) -> bool {
271 matches!(
272 node.kind(),
273 "contract_declaration"
274 | "interface_declaration"
275 | "library_declaration"
276 | "struct_declaration"
277 | "enum_declaration"
278 )
279}
280
281fn find_body(node: Node) -> Option<Node> {
282 match node.kind() {
283 "contract_declaration" | "interface_declaration" | "library_declaration" => {
284 find_child(node, "contract_body")
285 }
286 "struct_declaration" => find_child(node, "struct_body"),
287 "enum_declaration" => find_child(node, "enum_body"),
288 _ => None,
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the ranges whose `kind` equals `kind` (`None` selects the
    /// untagged, structural folds).
    fn folds_of_kind(
        ranges: &[FoldingRange],
        kind: Option<FoldingRangeKind>,
    ) -> Vec<&FoldingRange> {
        ranges.iter().filter(|r| r.kind == kind).collect()
    }

    /// Reports whether any range starts on `line` (zero-based).
    fn any_starts_on(ranges: &[FoldingRange], line: u32) -> bool {
        ranges.iter().any(|r| r.start_line == line)
    }

    #[test]
    fn test_empty_source() {
        let ranges = folding_ranges("");
        assert!(ranges.is_empty());
    }

    #[test]
    fn test_single_line_contract() {
        let ranges = folding_ranges("contract Foo {}");
        assert!(ranges.is_empty(), "single-line contract should not fold");
    }

    #[test]
    fn test_contract_body_fold() {
        let source = r#"
contract Counter {
    uint256 public count;
    function increment() public {
        count += 1;
    }
}
"#;
        let ranges = folding_ranges(source);
        let region_count = folds_of_kind(&ranges, None).len();
        assert!(
            region_count >= 2,
            "expected at least 2 region folds (contract body + function body), got {}",
            region_count
        );
    }

    #[test]
    fn test_function_body_fold() {
        let source = r#"
contract Foo {
    function bar() public {
        uint256 x = 1;
        uint256 y = 2;
    }
}
"#;
        let ranges = folding_ranges(source);
        let has_func_fold = ranges
            .iter()
            .any(|r| r.start_line == 2 && r.end_line == 5 && r.kind.is_none());
        assert!(
            has_func_fold,
            "expected fold for function body, got ranges: {:?}",
            ranges
                .iter()
                .map(|r| (r.start_line, r.end_line, &r.kind))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_struct_fold() {
        let source = r#"
struct Info {
    string name;
    uint256 value;
    address owner;
}
"#;
        let ranges = folding_ranges(source);
        assert!(any_starts_on(&ranges, 1), "expected fold for struct body");
    }

    #[test]
    fn test_enum_fold() {
        let source = r#"
enum Status {
    Active,
    Paused,
    Stopped
}
"#;
        let ranges = folding_ranges(source);
        assert!(any_starts_on(&ranges, 1), "expected fold for enum body");
    }

    #[test]
    fn test_block_comment_fold() {
        let source = r#"
/*
 * This is a multi-line
 * block comment
 */
contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let comment_folds = folds_of_kind(&ranges, Some(FoldingRangeKind::Comment));
        assert!(
            !comment_folds.is_empty(),
            "expected a comment fold for block comment"
        );
        let first = comment_folds[0];
        assert_eq!(first.start_line, 1);
        assert_eq!(first.end_line, 4);
    }

    #[test]
    fn test_consecutive_line_comments_fold() {
        let source = r#"// line 1
// line 2
// line 3
contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let comment_folds = folds_of_kind(&ranges, Some(FoldingRangeKind::Comment));
        assert!(
            !comment_folds.is_empty(),
            "expected a fold for consecutive line comments"
        );
        let first = comment_folds[0];
        assert_eq!(first.start_line, 0);
        assert_eq!(first.end_line, 2);
    }

    #[test]
    fn test_single_line_comment_no_fold() {
        let source = r#"
// just one line
contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let comment_folds = folds_of_kind(&ranges, Some(FoldingRangeKind::Comment));
        assert!(
            comment_folds.is_empty(),
            "single line comment should not produce a fold"
        );
    }

    #[test]
    fn test_import_group_fold() {
        let source = r#"
import "./A.sol";
import "./B.sol";
import "./C.sol";

contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let import_folds = folds_of_kind(&ranges, Some(FoldingRangeKind::Imports));
        assert!(
            !import_folds.is_empty(),
            "expected an import group fold for consecutive imports"
        );
        let has_group = import_folds
            .iter()
            .any(|r| r.start_line == 1 && r.end_line == 3);
        assert!(has_group, "expected group fold spanning lines 1-3");
    }

    #[test]
    fn test_multiline_import_fold() {
        let source = r#"
import {
    Foo,
    Bar,
    Baz
} from "./Lib.sol";
"#;
        let ranges = folding_ranges(source);
        let import_folds = folds_of_kind(&ranges, Some(FoldingRangeKind::Imports));
        assert!(
            !import_folds.is_empty(),
            "expected fold for multi-line import"
        );
    }

    #[test]
    fn test_shop_sol() {
        // Fixture path is relative to the directory the tests are run from.
        let source = std::fs::read_to_string("example/Shop.sol").unwrap();
        let ranges = folding_ranges(&source);
        let total = ranges.len();

        assert!(
            total >= 10,
            "Shop.sol should have at least 10 folding ranges, got {}",
            total
        );

        assert!(
            any_starts_on(&ranges, 22),
            "expected fold starting at library body (line 22)"
        );
    }

    #[test]
    fn test_interface_fold() {
        let source = r#"
interface IToken {
    function transfer(address to, uint256 amount) external returns (bool);
    function balanceOf(address account) external view returns (uint256);
}
"#;
        let ranges = folding_ranges(source);
        assert!(any_starts_on(&ranges, 1), "expected fold for interface body");
    }

    #[test]
    fn test_library_fold() {
        let source = r#"
library SafeMath {
    function add(uint256 a, uint256 b) internal pure returns (uint256) {
        return a + b;
    }
}
"#;
        let ranges = folding_ranges(source);
        assert!(
            ranges.len() >= 2,
            "library should produce at least 2 folds (body + function)"
        );
    }

    #[test]
    fn test_nested_blocks_fold() {
        let source = r#"
contract Foo {
    function bar() public {
        if (true) {
            uint256 x = 1;
        }
        for (uint256 i = 0; i < 10; i++) {
            uint256 y = i;
        }
    }
}
"#;
        let ranges = folding_ranges(source);
        let region_count = folds_of_kind(&ranges, None).len();
        assert!(
            region_count >= 4,
            "expected at least 4 folds for nested blocks, got {}",
            region_count
        );
    }

    #[test]
    fn test_modifier_fold() {
        let source = r#"
contract Foo {
    modifier onlyOwner() {
        require(msg.sender == owner);
        _;
    }
}
"#;
        let ranges = folding_ranges(source);
        assert!(any_starts_on(&ranges, 2), "expected fold for modifier body");
    }

    #[test]
    fn test_constructor_fold() {
        let source = r#"
contract Foo {
    constructor() {
        owner = msg.sender;
    }
}
"#;
        let ranges = folding_ranges(source);
        assert!(any_starts_on(&ranges, 2), "expected fold for constructor body");
    }

    #[test]
    fn test_inner_block_comment_fold() {
        let source = r#"
contract Foo {
    /*
     * This is a comment
     * inside a contract
     */
    function bar() public {}
}
"#;
        let ranges = folding_ranges(source);
        let comment_folds = folds_of_kind(&ranges, Some(FoldingRangeKind::Comment));
        assert!(
            !comment_folds.is_empty(),
            "expected comment fold inside contract body"
        );
    }
}