solidity_language_server/
selection.rs1use tower_lsp::lsp_types::{Position, Range, SelectionRange};
2use tree_sitter::{Node, Parser, Point};
3
4pub fn selection_ranges(source: &str, positions: &[Position]) -> Vec<SelectionRange> {
10 let tree = match parse(source) {
11 Some(t) => t,
12 None => return positions.iter().map(|_| empty_selection_range()).collect(),
13 };
14
15 let root = tree.root_node();
16
17 positions
18 .iter()
19 .map(|pos| build_selection_range(root, source, *pos))
20 .collect()
21}
22
23fn build_selection_range(root: Node, source: &str, position: Position) -> SelectionRange {
25 let point = Point {
26 row: position.line as usize,
27 column: position.character as usize,
28 };
29
30 let leaf = match root.descendant_for_point_range(point, point) {
32 Some(n) => n,
33 None => return empty_selection_range(),
34 };
35
36 let mut ancestors = Vec::new();
38 let mut current = leaf;
39 let mut last_range: Option<Range> = None;
40
41 loop {
42 let range = node_range(current, source);
43 if last_range != Some(range) {
46 ancestors.push(range);
47 last_range = Some(range);
48 }
49 match current.parent() {
50 Some(p) => current = p,
51 None => break,
52 }
53 }
54
55 let mut result: Option<SelectionRange> = None;
57 for range in ancestors.into_iter().rev() {
58 result = Some(SelectionRange {
59 range,
60 parent: result.map(Box::new),
61 });
62 }
63
64 result.unwrap_or_else(empty_selection_range)
65}
66
67fn node_range(node: Node, source: &str) -> Range {
70 let start = node.start_position();
71 let end = node.end_position();
72 Range {
73 start: Position {
74 line: start.row as u32,
75 character: utf16_col(source, start.row, start.column),
76 },
77 end: Position {
78 line: end.row as u32,
79 character: utf16_col(source, end.row, end.column),
80 },
81 }
82}
83
/// Converts a byte column on `row` of `source` into a UTF-16 code-unit column.
///
/// `row` is a zero-based line index, with lines delimited by `'\n'`.
/// `byte_col` is a byte offset within that line. The result is the number of
/// UTF-16 code units in the line's prefix up to `byte_col`.
///
/// Out-of-range input never panics:
/// - a `byte_col` past the end of the line is clamped to the line length;
/// - a `byte_col` inside a multi-byte character is rounded down to that
///   character's start;
/// - a `row` past the last line yields 0.
fn utf16_col(source: &str, row: usize, byte_col: usize) -> u32 {
    // Find the line itself instead of summing preceding line lengths. That way
    // the slice below cannot underflow or run past the line into the next one.
    let line = match source.split('\n').nth(row) {
        Some(line) => line,
        None => return 0,
    };
    let mut end = byte_col.min(line.len());
    // Slicing a `str` off a char boundary panics; back up to the nearest one.
    // Index 0 is always a boundary, so this terminates.
    while !line.is_char_boundary(end) {
        end -= 1;
    }
    let units = line[..end].encode_utf16().count();
    // Saturate rather than wrap on pathologically long lines.
    units.min(u32::MAX as usize) as u32
}
95
96fn empty_selection_range() -> SelectionRange {
98 SelectionRange {
99 range: Range::default(),
100 parent: None,
101 }
102}
103
104fn parse(source: &str) -> Option<tree_sitter::Tree> {
105 let mut parser = Parser::new();
106 parser
107 .set_language(&tree_sitter_solidity::LANGUAGE.into())
108 .ok()?;
109 parser.parse(source, None)
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a selection-range chain from the innermost range outward.
    fn chain(sr: &SelectionRange) -> Vec<Range> {
        std::iter::successors(Some(sr), |node| node.parent.as_deref())
            .map(|node| node.range)
            .collect()
    }

    /// Number of ranges in the chain, the innermost one included.
    fn depth(sr: &SelectionRange) -> usize {
        chain(sr).len()
    }

    /// Whether `outer` begins at or before `inner` and ends at or after it.
    fn encloses(outer: &Range, inner: &Range) -> bool {
        let begins_first = outer.start.line < inner.start.line
            || (outer.start.line == inner.start.line
                && outer.start.character <= inner.start.character);
        let ends_last = outer.end.line > inner.end.line
            || (outer.end.line == inner.end.line
                && outer.end.character >= inner.end.character);
        begins_first && ends_last
    }

    #[test]
    fn test_empty_source() {
        let results = selection_ranges("", &[Position::new(0, 0)]);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_single_identifier() {
        let src =
            "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\nuint256 constant X = 1;\n";
        let results = selection_ranges(src, &[Position::new(3, 18)]);
        assert_eq!(results.len(), 1);
        let ranges = chain(&results[0]);
        // The innermost range is a single-line token.
        let innermost = &ranges[0];
        assert_eq!(innermost.start.line, innermost.end.line);
        // The outermost range begins at the top of the file.
        let outermost = ranges.last().unwrap();
        assert_eq!(outermost.start.line, 0);
    }

    #[test]
    fn test_multiple_positions() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n uint256 x;\n uint256 y;\n}\n";
        let results = selection_ranges(
            src,
            &[
                Position::new(4, 12),
                Position::new(5, 12),
            ],
        );
        assert_eq!(results.len(), 2);
        for sr in &results {
            assert!(depth(sr) >= 3);
        }
    }

    #[test]
    fn test_ranges_are_nested() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n function bar() public pure returns (uint256) {\n return 42;\n }\n}\n";
        let results = selection_ranges(src, &[Position::new(5, 15)]);
        let ranges = chain(&results[0]);
        for (idx, pair) in ranges.windows(2).enumerate() {
            let (inner, outer) = (&pair[0], &pair[1]);
            assert!(
                encloses(outer, inner),
                "range[{}] {:?} must contain range[{}] {:?}",
                idx + 1,
                outer,
                idx,
                inner
            );
        }
    }

    #[test]
    fn test_no_duplicate_ranges() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n uint256 x;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        let ranges = chain(&results[0]);
        for pair in ranges.windows(2) {
            assert_ne!(
                pair[0],
                pair[1],
                "consecutive ranges should not be identical"
            );
        }
    }

    #[test]
    fn test_function_parameter() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n function add(uint256 a, uint256 b) public pure returns (uint256) {\n return a + b;\n }\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 25)]);
        let d = depth(&results[0]);
        assert!(d >= 5, "depth = {}", d);
    }

    #[test]
    fn test_struct_field() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\nstruct Point {\n uint256 x;\n uint256 y;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        let ranges = chain(&results[0]);
        assert!(depth(&results[0]) >= 3);
        let outermost = ranges.last().unwrap();
        assert_eq!(outermost.start.line, 0);
    }

    #[test]
    fn test_nested_expression() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n function calc(uint256 a, uint16 b, uint16 c) internal pure returns (uint256) {\n return a + (a * b / c);\n }\n}\n";
        let results = selection_ranges(src, &[Position::new(5, 24)]);
        let d = depth(&results[0]);
        assert!(d >= 8, "nested expression depth = {}", d);
    }

    #[test]
    fn test_event_parameter() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n event Transfer(address indexed from, address indexed to, uint256 value);\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 35)]);
        assert!(depth(&results[0]) >= 3);
    }

    #[test]
    fn test_mapping_type() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n mapping(address => uint256) public balances;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        assert!(depth(&results[0]) >= 3);
    }

    #[test]
    fn test_if_condition() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n function check(uint256 x) public pure returns (bool) {\n if (x > 0) {\n return true;\n }\n return false;\n }\n}\n";
        let results = selection_ranges(src, &[Position::new(5, 12)]);
        assert!(depth(&results[0]) >= 6);
    }

    #[test]
    fn test_comment_position() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\n// This is a comment\ncontract Foo {}\n";
        let results = selection_ranges(src, &[Position::new(3, 5)]);
        assert!(depth(&results[0]) >= 1);
    }

    #[test]
    fn test_pragma_position() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {}\n";
        let results = selection_ranges(src, &[Position::new(1, 10)]);
        assert!(depth(&results[0]) >= 2);
    }

    #[test]
    fn test_innermost_is_leaf() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n uint256 public value;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 19)]);
        let ranges = chain(&results[0]);
        let innermost = &ranges[0];
        assert_eq!(innermost.start.line, innermost.end.line);
        // The identifier `value` is five characters wide.
        assert_eq!(innermost.end.character - innermost.start.character, 5);
    }

    #[test]
    fn test_outermost_is_source_file() {
        let src = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract Foo {\n uint256 x;\n}\n";
        let results = selection_ranges(src, &[Position::new(4, 12)]);
        let ranges = chain(&results[0]);
        let outermost = ranges.last().unwrap();
        assert_eq!(outermost.start.line, 0);
        assert_eq!(outermost.start.character, 0);
    }

    #[test]
    fn test_shop_sol() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("example/Shop.sol");
        let source = std::fs::read_to_string(&path).expect("read Shop.sol");

        let positions = [
            Position::new(42, 15),
            Position::new(29, 8),
            Position::new(68, 22),
            Position::new(0, 5),
        ];
        let results = selection_ranges(&source, &positions);
        assert_eq!(results.len(), 4);

        for (i, sr) in results.iter().enumerate() {
            let d = depth(sr);
            assert!(d >= 2, "position {} should have depth >= 2, got {}", i, d);
        }

        for (i, sr) in results.iter().enumerate() {
            let ranges = chain(sr);
            for (j, pair) in ranges.windows(2).enumerate() {
                let (inner, outer) = (&pair[0], &pair[1]);
                assert!(
                    encloses(outer, inner),
                    "position {}: range[{}] {:?} must contain range[{}] {:?}",
                    i,
                    j + 1,
                    outer,
                    j,
                    inner
                );
            }
        }
    }
}