Skip to main content

solidity_language_server/
selection.rs

1use tower_lsp::lsp_types::{Position, Range, SelectionRange};
2use tree_sitter::{Node, Parser, Point};
3
4/// Compute selection ranges for each requested position.
5///
6/// For every position, walks up the tree-sitter node ancestry from the
7/// deepest (leaf) node to the root, producing a linked list of
8/// `SelectionRange` values that editors use for expand/shrink selection.
9pub fn selection_ranges(source: &str, positions: &[Position]) -> Vec<SelectionRange> {
10    let tree = match parse(source) {
11        Some(t) => t,
12        None => return positions.iter().map(|_| empty_selection_range()).collect(),
13    };
14
15    let root = tree.root_node();
16
17    positions
18        .iter()
19        .map(|pos| build_selection_range(root, source, *pos))
20        .collect()
21}
22
23/// Build the nested `SelectionRange` chain for a single cursor position.
24fn build_selection_range(root: Node, source: &str, position: Position) -> SelectionRange {
25    let point = Point {
26        row: position.line as usize,
27        column: position.character as usize,
28    };
29
30    // Find the deepest node at the cursor position.
31    let leaf = match root.descendant_for_point_range(point, point) {
32        Some(n) => n,
33        None => return empty_selection_range(),
34    };
35
36    // Walk up the ancestry, collecting each node with a distinct range.
37    let mut ancestors = Vec::new();
38    let mut current = leaf;
39    let mut last_range: Option<Range> = None;
40
41    loop {
42        let range = node_range(current, source);
43        // Only push if the range differs from the previous one — avoids
44        // redundant wrapper nodes that span the exact same region.
45        if last_range != Some(range) {
46            ancestors.push(range);
47            last_range = Some(range);
48        }
49        match current.parent() {
50            Some(p) => current = p,
51            None => break,
52        }
53    }
54
55    // Build the linked list from outermost (last) to innermost (first).
56    let mut result: Option<SelectionRange> = None;
57    for range in ancestors.into_iter().rev() {
58        result = Some(SelectionRange {
59            range,
60            parent: result.map(Box::new),
61        });
62    }
63
64    result.unwrap_or_else(empty_selection_range)
65}
66
67/// Convert a tree-sitter node to an LSP `Range`, accounting for UTF-16
68/// column offsets.
69fn node_range(node: Node, source: &str) -> Range {
70    let start = node.start_position();
71    let end = node.end_position();
72    Range {
73        start: Position {
74            line: start.row as u32,
75            character: utf16_col(source, start.row, start.column),
76        },
77        end: Position {
78            line: end.row as u32,
79            character: utf16_col(source, end.row, end.column),
80        },
81    }
82}
83
/// Convert a byte-column offset on line `row` of `source` into a UTF-16
/// code-unit offset (the column unit LSP clients expect).
///
/// - If `row` does not exist in `source`, `byte_col` is returned unchanged.
/// - A column past the end of the line is clamped to the line's length.
/// - A column that falls inside a multi-byte character counts only the
///   characters that end at or before it.
///
/// The function therefore never panics. It scans `source` from the start
/// up to `row` on every call.
fn utf16_col(source: &str, row: usize, byte_col: usize) -> u32 {
    let Some(line) = source.split('\n').nth(row) else {
        return byte_col as u32;
    };
    let end = byte_col.min(line.len());
    line.char_indices()
        // Keep only characters lying entirely before `end`; this avoids the
        // panic a byte slice would raise on a non-char-boundary offset.
        .take_while(|&(idx, ch)| idx + ch.len_utf8() <= end)
        .map(|(_, ch)| ch.len_utf16())
        .sum::<usize>() as u32
}

96/// A zero-width selection range at origin — used as fallback.
97fn empty_selection_range() -> SelectionRange {
98    SelectionRange {
99        range: Range::default(),
100        parent: None,
101    }
102}
103
104fn parse(source: &str) -> Option<tree_sitter::Tree> {
105    let mut parser = Parser::new();
106    parser
107        .set_language(&tree_sitter_solidity::LANGUAGE.into())
108        .ok()?;
109    parser.parse(source, None)
110}
111
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Flatten a selection chain into its ranges, innermost first.
    fn chain(sr: &SelectionRange) -> Vec<Range> {
        std::iter::successors(Some(sr), |&node| node.parent.as_deref())
            .map(|node| node.range)
            .collect()
    }

    /// Number of ranges in the chain.
    fn depth(sr: &SelectionRange) -> usize {
        chain(sr).len()
    }

    /// True when `outer` fully encloses `inner` (both endpoints inclusive).
    fn encloses(outer: &Range, inner: &Range) -> bool {
        let starts_before = outer.start.line < inner.start.line
            || (outer.start.line == inner.start.line
                && outer.start.character <= inner.start.character);
        let ends_after = outer.end.line > inner.end.line
            || (outer.end.line == inner.end.line && outer.end.character >= inner.end.character);
        starts_before && ends_after
    }

    // -----------------------------------------------------------------------
    // Basic behaviour
    // -----------------------------------------------------------------------

    #[test]
    fn test_empty_source() {
        let results = selection_ranges("", &[Position::new(0, 0)]);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_single_identifier() {
        let src =
            "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\nuint256 constant X = 1;\n";
        // Cursor on 'X' (line 3, column 18).
        let results = selection_ranges(src, &[Position::new(3, 18)]);
        assert_eq!(results.len(), 1);
        let ranges = chain(&results[0]);
        // The identifier fits on a single line.
        let innermost = &ranges[0];
        assert_eq!(innermost.start.line, innermost.end.line);
        // The outermost range is the whole file, starting on line 0.
        let outermost = ranges.last().unwrap();
        assert_eq!(outermost.start.line, 0);
    }

    #[test]
    fn test_multiple_positions() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    uint256 x;\n    uint256 y;\n}\n";
        let positions = [
            Position::new(4, 12), // 'x'
            Position::new(5, 12), // 'y'
        ];
        let results = selection_ranges(src, &positions);
        assert_eq!(results.len(), 2);
        // Each chain must be non-trivial.
        for sr in &results {
            assert!(depth(sr) >= 3);
        }
    }

    #[test]
    fn test_ranges_are_nested() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    function bar() public pure returns (uint256) {\n        return 42;\n    }\n}\n";
        // Cursor on '42' inside `return 42;`.
        let results = selection_ranges(src, &[Position::new(5, 15)]);
        let ranges = chain(&results[0]);
        // Every range must enclose the one before it.
        for (k, pair) in ranges.windows(2).enumerate() {
            let (inner, outer) = (&pair[0], &pair[1]);
            assert!(
                encloses(outer, inner),
                "range[{}] {:?} must contain range[{}] {:?}",
                k + 1,
                outer,
                k,
                inner
            );
        }
    }

    #[test]
    fn test_no_duplicate_ranges() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    uint256 x;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        let ranges = chain(&results[0]);
        // Adjacent entries in the chain must differ.
        for pair in ranges.windows(2) {
            assert_ne!(pair[0], pair[1], "consecutive ranges should not be identical");
        }
    }

    // -----------------------------------------------------------------------
    // Specific Solidity constructs
    // -----------------------------------------------------------------------

    #[test]
    fn test_function_parameter() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    function add(uint256 a, uint256 b) public pure returns (uint256) {\n        return a + b;\n    }\n}\n";
        // Cursor on the parameter 'a' (line 4).
        let results = selection_ranges(src, &[Position::new(4, 25)]);
        // identifier → parameter → parameter list → function → contract body → contract → source_file
        let d = depth(&results[0]);
        assert!(d >= 5, "depth = {}", d);
    }

    #[test]
    fn test_struct_field() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\nstruct Point {\n    uint256 x;\n    uint256 y;\n}\n";
        // Cursor on 'x' (line 4).
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        let ranges = chain(&results[0]);
        assert!(depth(&results[0]) >= 3);
        // The outermost range is source_file.
        let outermost = ranges.last().unwrap();
        assert_eq!(outermost.start.line, 0);
    }

    #[test]
    fn test_nested_expression() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    function calc(uint256 a, uint16 b, uint16 c) internal pure returns (uint256) {\n        return a + (a * b / c);\n    }\n}\n";
        // Cursor on 'b' inside the parenthesised expression (line 5).
        let results = selection_ranges(src, &[Position::new(5, 24)]);
        // b → a * b → a * b / c → (a * b / c) → a + (...) → return stmt → body → function → contract body → contract → source
        let d = depth(&results[0]);
        assert!(d >= 8, "nested expression depth = {}", d);
    }

    #[test]
    fn test_event_parameter() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    event Transfer(address indexed from, address indexed to, uint256 value);\n}\n";
        // Cursor on 'from' (line 4).
        let results = selection_ranges(src, &[Position::new(4, 35)]);
        assert!(depth(&results[0]) >= 3);
    }

    #[test]
    fn test_mapping_type() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    mapping(address => uint256) public balances;\n}\n";
        // Cursor on 'address' (line 4).
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        assert!(depth(&results[0]) >= 3);
    }

    #[test]
    fn test_if_condition() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    function check(uint256 x) public pure returns (bool) {\n        if (x > 0) {\n            return true;\n        }\n        return false;\n    }\n}\n";
        // Cursor on 'x' in the condition (line 5).
        let results = selection_ranges(src, &[Position::new(5, 12)]);
        // x → x > 0 → if statement → body → function → contract body → contract → source
        assert!(depth(&results[0]) >= 6);
    }

    #[test]
    fn test_comment_position() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\n// This is a comment\ncontract Foo {}\n";
        // Cursor inside the comment (line 3).
        let results = selection_ranges(src, &[Position::new(3, 5)]);
        assert!(depth(&results[0]) >= 1);
    }

    #[test]
    fn test_pragma_position() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {}\n";
        // Cursor on 'solidity' in the pragma (line 1).
        let results = selection_ranges(src, &[Position::new(1, 10)]);
        assert!(depth(&results[0]) >= 2);
    }

    #[test]
    fn test_innermost_is_leaf() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    uint256 public value;\n}\n";
        // Cursor on 'value' (line 4).
        let results = selection_ranges(src, &[Position::new(4, 19)]);
        let ranges = chain(&results[0]);
        // The innermost range hugs the 5-character identifier "value".
        let innermost = &ranges[0];
        assert_eq!(innermost.start.line, innermost.end.line);
        assert_eq!(innermost.end.character - innermost.start.character, 5);
    }

    #[test]
    fn test_outermost_is_source_file() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n    uint256 x;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        let ranges = chain(&results[0]);
        // source_file begins at (0, 0).
        let outermost = ranges.last().unwrap();
        assert_eq!(outermost.start.line, 0);
        assert_eq!(outermost.start.character, 0);
    }

    // -----------------------------------------------------------------------
    // Integration: Shop.sol
    // -----------------------------------------------------------------------

    #[test]
    fn test_shop_sol() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("example/Shop.sol");
        let source = std::fs::read_to_string(&path).expect("read Shop.sol");

        // Probe several positions spread across the file.
        let positions = vec![
            Position::new(42, 15), // inside addTax function body
            Position::new(29, 8),  // struct field 'buyer'
            Position::new(68, 22), // PRICE constant
            Position::new(0, 5),   // comment at top
        ];
        let results = selection_ranges(&source, &positions);
        assert_eq!(results.len(), 4);

        // Every chain must be non-trivial.
        for (i, sr) in results.iter().enumerate() {
            let d = depth(sr);
            assert!(d >= 2, "position {} should have depth >= 2, got {}", i, d);
        }

        // Every chain must consist of properly nested ranges.
        for (i, sr) in results.iter().enumerate() {
            let ranges = chain(sr);
            for (k, pair) in ranges.windows(2).enumerate() {
                let (inner, outer) = (&pair[0], &pair[1]);
                assert!(
                    encloses(outer, inner),
                    "position {}: range[{}] {:?} must contain range[{}] {:?}",
                    i,
                    k + 1,
                    outer,
                    k,
                    inner
                );
            }
        }
    }
}