//! `solidity_language_server/folding.rs`
//!
//! Folding-range support for Solidity sources.

1use tower_lsp::lsp_types::{FoldingRange, FoldingRangeKind};
2use tree_sitter::{Node, Parser};
3
4/// Extract folding ranges from Solidity source using tree-sitter.
5///
6/// Returns ranges for contracts, functions, structs, enums, block statements,
7/// multi-line comments, consecutive single-line comments, and import groups.
8pub fn folding_ranges(source: &str) -> Vec<FoldingRange> {
9    let tree = match parse(source) {
10        Some(t) => t,
11        None => return vec![],
12    };
13    let mut ranges = Vec::new();
14    collect_folding_ranges(tree.root_node(), source, &mut ranges);
15    collect_comment_folds(tree.root_node(), source, &mut ranges);
16    collect_import_folds(tree.root_node(), &mut ranges);
17    ranges
18}
19
20/// Recursively walk the tree and emit folding ranges for multi-line nodes.
21fn collect_folding_ranges(node: Node, source: &str, out: &mut Vec<FoldingRange>) {
22    match node.kind() {
23        // Top-level declarations with bodies — emit a fold for the body then
24        // recurse into the body's children (functions, state vars, etc.).
25        "contract_declaration" | "interface_declaration" | "library_declaration" => {
26            if let Some(body) = find_child(node, "contract_body") {
27                push_brace_fold(body, None, out);
28                walk_children(body, source, out);
29            }
30            return;
31        }
32        "struct_declaration" => {
33            if let Some(body) = find_child(node, "struct_body") {
34                push_brace_fold(body, None, out);
35            }
36            return;
37        }
38        "enum_declaration" => {
39            if let Some(body) = find_child(node, "enum_body") {
40                push_brace_fold(body, None, out);
41            }
42            return;
43        }
44
45        // Functions, constructors, modifiers, fallback/receive — emit a fold
46        // for the function body then recurse into it for nested blocks.
47        "function_definition"
48        | "constructor_definition"
49        | "modifier_definition"
50        | "fallback_receive_definition" => {
51            if let Some(body) = find_child(node, "function_body") {
52                push_brace_fold(body, None, out);
53                walk_children(body, source, out);
54            }
55            return;
56        }
57
58        // Block statements inside function bodies
59        "block_statement" | "unchecked_block" => {
60            push_brace_fold(node, None, out);
61        }
62
63        // Control-flow with braces — recurse into children which will emit
64        // folds for their block_statement bodies.
65        "if_statement" | "for_statement" | "while_statement" | "do_while_statement"
66        | "try_statement" => {}
67
68        // Assembly blocks
69        "assembly_statement" => {
70            if let Some(body) = find_child(node, "yul_block") {
71                push_brace_fold(body, None, out);
72            }
73        }
74
75        // Event/error with multi-line parameter lists
76        "event_definition" | "error_declaration" => {
77            push_multiline_fold(node, None, out);
78        }
79
80        _ => {}
81    }
82
83    walk_children(node, source, out);
84}
85
86fn walk_children(node: Node, source: &str, out: &mut Vec<FoldingRange>) {
87    let mut cursor = node.walk();
88    for child in node.children(&mut cursor) {
89        if child.is_named() {
90            collect_folding_ranges(child, source, out);
91        }
92    }
93}
94
95/// Collect folding ranges for comments.
96///
97/// - Multi-line block comments (`/* ... */`) get a Comment fold.
98/// - Consecutive single-line comments (`// ...`) on adjacent lines are grouped
99///   into a single Comment fold.
100fn collect_comment_folds(root: Node, source: &str, out: &mut Vec<FoldingRange>) {
101    let mut cursor = root.walk();
102    let children: Vec<Node> = root
103        .children(&mut cursor)
104        .filter(|c| c.kind() == "comment")
105        .collect();
106
107    let mut i = 0;
108    while i < children.len() {
109        let node = children[i];
110        let text = &source[node.byte_range()];
111        let start_line = node.start_position().row as u32;
112        let end_line = node.end_position().row as u32;
113
114        if text.starts_with("/*") {
115            // Multi-line block comment
116            if end_line > start_line {
117                out.push(FoldingRange {
118                    start_line,
119                    start_character: Some(node.start_position().column as u32),
120                    end_line,
121                    end_character: Some(node.end_position().column as u32),
122                    kind: Some(FoldingRangeKind::Comment),
123                    collapsed_text: None,
124                });
125            }
126            i += 1;
127        } else if text.starts_with("//") {
128            // Group consecutive single-line comments
129            let group_start = start_line;
130            let mut group_end = end_line;
131            let mut j = i + 1;
132            while j < children.len() {
133                let next = children[j];
134                let next_text = &source[next.byte_range()];
135                let next_start = next.start_position().row as u32;
136                if next_text.starts_with("//") && next_start == group_end + 1 {
137                    group_end = next.end_position().row as u32;
138                    j += 1;
139                } else {
140                    break;
141                }
142            }
143            if group_end > group_start {
144                out.push(FoldingRange {
145                    start_line: group_start,
146                    start_character: Some(node.start_position().column as u32),
147                    end_line: group_end,
148                    end_character: None,
149                    kind: Some(FoldingRangeKind::Comment),
150                    collapsed_text: None,
151                });
152            }
153            i = j;
154        } else {
155            i += 1;
156        }
157    }
158
159    // Also recurse into contract/struct/enum bodies for inner comments
160    let mut cursor2 = root.walk();
161    for child in root.children(&mut cursor2) {
162        if child.is_named()
163            && has_body(child)
164            && let Some(body) = find_body(child)
165        {
166            collect_comment_folds(body, source, out);
167        }
168    }
169}
170
171/// Group consecutive `import_directive` nodes into a single Imports fold.
172fn collect_import_folds(root: Node, out: &mut Vec<FoldingRange>) {
173    let mut cursor = root.walk();
174    let children: Vec<Node> = root
175        .children(&mut cursor)
176        .filter(|c| c.is_named())
177        .collect();
178
179    let mut i = 0;
180    while i < children.len() {
181        if children[i].kind() == "import_directive" {
182            let start_line = children[i].start_position().row as u32;
183            let start_char = children[i].start_position().column as u32;
184            let mut end_line = children[i].end_position().row as u32;
185
186            // Also fold individual multi-line imports (e.g. `import { A, B, C } from "...";`)
187            if end_line > start_line {
188                out.push(FoldingRange {
189                    start_line,
190                    start_character: Some(start_char),
191                    end_line,
192                    end_character: Some(children[i].end_position().column as u32),
193                    kind: Some(FoldingRangeKind::Imports),
194                    collapsed_text: None,
195                });
196            }
197
198            // Group consecutive imports
199            let mut j = i + 1;
200            while j < children.len() && children[j].kind() == "import_directive" {
201                end_line = children[j].end_position().row as u32;
202                j += 1;
203            }
204            if j > i + 1 {
205                // Multiple consecutive imports — create a group fold
206                out.push(FoldingRange {
207                    start_line,
208                    start_character: Some(start_char),
209                    end_line,
210                    end_character: None,
211                    kind: Some(FoldingRangeKind::Imports),
212                    collapsed_text: None,
213                });
214            }
215            i = j;
216        } else {
217            i += 1;
218        }
219    }
220}
221
222// ── Helpers ────────────────────────────────────────────────────────────────
223
224fn parse(source: &str) -> Option<tree_sitter::Tree> {
225    let mut parser = Parser::new();
226    parser
227        .set_language(&tree_sitter_solidity::LANGUAGE.into())
228        .expect("failed to load Solidity grammar");
229    parser.parse(source, None)
230}
231
232/// Push a fold for a brace-delimited node (e.g. `{ ... }`).
233/// Only emits a fold when the node spans multiple lines.
234fn push_brace_fold(node: Node, kind: Option<FoldingRangeKind>, out: &mut Vec<FoldingRange>) {
235    let start_line = node.start_position().row as u32;
236    let end_line = node.end_position().row as u32;
237    if end_line > start_line {
238        out.push(FoldingRange {
239            start_line,
240            start_character: Some(node.start_position().column as u32),
241            end_line,
242            end_character: Some(node.end_position().column as u32),
243            kind,
244            collapsed_text: None,
245        });
246    }
247}
248
249/// Push a fold for any multi-line node (events, errors with long param lists).
250fn push_multiline_fold(node: Node, kind: Option<FoldingRangeKind>, out: &mut Vec<FoldingRange>) {
251    let start_line = node.start_position().row as u32;
252    let end_line = node.end_position().row as u32;
253    if end_line > start_line {
254        out.push(FoldingRange {
255            start_line,
256            start_character: Some(node.start_position().column as u32),
257            end_line,
258            end_character: Some(node.end_position().column as u32),
259            kind,
260            collapsed_text: None,
261        });
262    }
263}
264
265fn find_child<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
266    let mut cursor = node.walk();
267    node.children(&mut cursor).find(|c| c.kind() == kind)
268}
269
270fn has_body(node: Node) -> bool {
271    matches!(
272        node.kind(),
273        "contract_declaration"
274            | "interface_declaration"
275            | "library_declaration"
276            | "struct_declaration"
277            | "enum_declaration"
278    )
279}
280
281fn find_body(node: Node) -> Option<Node> {
282    match node.kind() {
283        "contract_declaration" | "interface_declaration" | "library_declaration" => {
284            find_child(node, "contract_body")
285        }
286        "struct_declaration" => find_child(node, "struct_body"),
287        "enum_declaration" => find_child(node, "enum_body"),
288        _ => None,
289    }
290}
291
292// ── Tests ──────────────────────────────────────────────────────────────────
293
#[cfg(test)]
mod tests {
    use super::*;

    // Unless a test says otherwise, the Solidity snippets below start with a
    // newline, so the first source line is line 1 (0-based line numbers).

    #[test]
    fn test_empty_source() {
        assert!(folding_ranges("").is_empty());
    }

    #[test]
    fn test_single_line_contract() {
        // No folds for single-line constructs
        let source = "contract Foo {}";
        let ranges = folding_ranges(source);
        assert!(ranges.is_empty(), "single-line contract should not fold");
    }

    #[test]
    fn test_contract_body_fold() {
        let source = r#"
contract Counter {
    uint256 public count;
    function increment() public {
        count += 1;
    }
}
"#;
        let ranges = folding_ranges(source);
        // Should have folds for: contract body, function body
        let contract_folds: Vec<_> = ranges.iter().filter(|r| r.kind.is_none()).collect();
        assert!(
            contract_folds.len() >= 2,
            "expected at least 2 region folds (contract body + function body), got {}",
            contract_folds.len()
        );
    }

    #[test]
    fn test_function_body_fold() {
        let source = r#"
contract Foo {
    function bar() public {
        uint256 x = 1;
        uint256 y = 2;
    }
}
"#;
        let ranges = folding_ranges(source);
        // The function body `{ ... }` starts on line 2 (same line as the
        // function signature) and ends on line 5 (`}`).
        let func_fold = ranges
            .iter()
            .find(|r| r.start_line == 2 && r.end_line == 5 && r.kind.is_none());
        assert!(
            func_fold.is_some(),
            "expected fold for function body, got ranges: {:?}",
            ranges
                .iter()
                .map(|r| (r.start_line, r.end_line, &r.kind))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_struct_fold() {
        let source = r#"
struct Info {
    string name;
    uint256 value;
    address owner;
}
"#;
        let ranges = folding_ranges(source);
        let struct_fold = ranges.iter().find(|r| r.start_line == 1);
        assert!(struct_fold.is_some(), "expected fold for struct body");
    }

    #[test]
    fn test_enum_fold() {
        let source = r#"
enum Status {
    Active,
    Paused,
    Stopped
}
"#;
        let ranges = folding_ranges(source);
        let enum_fold = ranges.iter().find(|r| r.start_line == 1);
        assert!(enum_fold.is_some(), "expected fold for enum body");
    }

    #[test]
    fn test_block_comment_fold() {
        let source = r#"
/*
 * This is a multi-line
 * block comment
 */
contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let comment_folds: Vec<_> = ranges
            .iter()
            .filter(|r| r.kind == Some(FoldingRangeKind::Comment))
            .collect();
        assert!(
            !comment_folds.is_empty(),
            "expected a comment fold for block comment"
        );
        assert_eq!(comment_folds[0].start_line, 1);
        assert_eq!(comment_folds[0].end_line, 4);
    }

    #[test]
    fn test_consecutive_line_comments_fold() {
        // No leading newline here: the first comment is on line 0.
        let source = r#"// line 1
// line 2
// line 3
contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let comment_folds: Vec<_> = ranges
            .iter()
            .filter(|r| r.kind == Some(FoldingRangeKind::Comment))
            .collect();
        assert!(
            !comment_folds.is_empty(),
            "expected a fold for consecutive line comments"
        );
        assert_eq!(comment_folds[0].start_line, 0);
        assert_eq!(comment_folds[0].end_line, 2);
    }

    #[test]
    fn test_single_line_comment_no_fold() {
        let source = r#"
// just one line
contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let comment_folds: Vec<_> = ranges
            .iter()
            .filter(|r| r.kind == Some(FoldingRangeKind::Comment))
            .collect();
        assert!(
            comment_folds.is_empty(),
            "single line comment should not produce a fold"
        );
    }

    #[test]
    fn test_import_group_fold() {
        let source = r#"
import "./A.sol";
import "./B.sol";
import "./C.sol";

contract Foo {}
"#;
        let ranges = folding_ranges(source);
        let import_folds: Vec<_> = ranges
            .iter()
            .filter(|r| r.kind == Some(FoldingRangeKind::Imports))
            .collect();
        assert!(
            !import_folds.is_empty(),
            "expected an import group fold for consecutive imports"
        );
        // The group fold should span from first to last import
        let group = import_folds
            .iter()
            .find(|r| r.start_line == 1 && r.end_line == 3);
        assert!(group.is_some(), "expected group fold spanning lines 1-3");
    }

    #[test]
    fn test_multiline_import_fold() {
        let source = r#"
import {
    Foo,
    Bar,
    Baz
} from "./Lib.sol";
"#;
        let ranges = folding_ranges(source);
        let import_folds: Vec<_> = ranges
            .iter()
            .filter(|r| r.kind == Some(FoldingRangeKind::Imports))
            .collect();
        assert!(
            !import_folds.is_empty(),
            "expected fold for multi-line import"
        );
    }

    #[test]
    fn test_shop_sol() {
        // Resolve the fixture against the crate root rather than the process
        // working directory, so the test does not depend on where the test
        // binary is launched from.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/example/Shop.sol");
        let source = std::fs::read_to_string(path)
            .unwrap_or_else(|err| panic!("failed to read fixture {path}: {err}"));
        let ranges = folding_ranges(&source);

        // Shop.sol has library + contract bodies, many functions, comments, etc.
        assert!(
            ranges.len() >= 10,
            "Shop.sol should have at least 10 folding ranges, got {}",
            ranges.len()
        );

        // Library body fold (Transaction at line 22)
        let lib_fold = ranges.iter().find(|r| r.start_line == 22);
        assert!(
            lib_fold.is_some(),
            "expected fold starting at library body (line 22)"
        );
    }

    #[test]
    fn test_interface_fold() {
        let source = r#"
interface IToken {
    function transfer(address to, uint256 amount) external returns (bool);
    function balanceOf(address account) external view returns (uint256);
}
"#;
        let ranges = folding_ranges(source);
        let interface_fold = ranges.iter().find(|r| r.start_line == 1);
        assert!(interface_fold.is_some(), "expected fold for interface body");
    }

    #[test]
    fn test_library_fold() {
        let source = r#"
library SafeMath {
    function add(uint256 a, uint256 b) internal pure returns (uint256) {
        return a + b;
    }
}
"#;
        let ranges = folding_ranges(source);
        assert!(
            ranges.len() >= 2,
            "library should produce at least 2 folds (body + function)"
        );
    }

    #[test]
    fn test_nested_blocks_fold() {
        let source = r#"
contract Foo {
    function bar() public {
        if (true) {
            uint256 x = 1;
        }
        for (uint256 i = 0; i < 10; i++) {
            uint256 y = i;
        }
    }
}
"#;
        let ranges = folding_ranges(source);
        // Should have folds for: contract body, function body, if block, for block
        let region_folds: Vec<_> = ranges.iter().filter(|r| r.kind.is_none()).collect();
        assert!(
            region_folds.len() >= 4,
            "expected at least 4 folds for nested blocks, got {}",
            region_folds.len()
        );
    }

    #[test]
    fn test_modifier_fold() {
        let source = r#"
contract Foo {
    modifier onlyOwner() {
        require(msg.sender == owner);
        _;
    }
}
"#;
        let ranges = folding_ranges(source);
        // Should fold the modifier body
        let modifier_fold = ranges.iter().find(|r| r.start_line == 2);
        assert!(modifier_fold.is_some(), "expected fold for modifier body");
    }

    #[test]
    fn test_constructor_fold() {
        let source = r#"
contract Foo {
    constructor() {
        owner = msg.sender;
    }
}
"#;
        let ranges = folding_ranges(source);
        let ctor_fold = ranges.iter().find(|r| r.start_line == 2);
        assert!(ctor_fold.is_some(), "expected fold for constructor body");
    }

    #[test]
    fn test_inner_block_comment_fold() {
        let source = r#"
contract Foo {
    /*
     * This is a comment
     * inside a contract
     */
    function bar() public {}
}
"#;
        let ranges = folding_ranges(source);
        let comment_folds: Vec<_> = ranges
            .iter()
            .filter(|r| r.kind == Some(FoldingRangeKind::Comment))
            .collect();
        assert!(
            !comment_folds.is_empty(),
            "expected comment fold inside contract body"
        );
    }
}
609            !comment_folds.is_empty(),
610            "expected comment fold inside contract body"
611        );
612    }
613}