1#![allow(deprecated)]
2
3use tower_lsp::lsp_types::{
4 DocumentSymbol, Location, Position, Range, SymbolInformation, SymbolKind, Url,
5};
6use tree_sitter::{Node, Parser};
7
8pub fn extract_document_symbols(source: &str) -> Vec<DocumentSymbol> {
12 let tree = match parse(source) {
13 Some(t) => t,
14 None => return vec![],
15 };
16 collect_top_level(tree.root_node(), source)
17}
18
19fn collect_top_level(node: Node, source: &str) -> Vec<DocumentSymbol> {
20 named_children(node)
21 .filter_map(|child| match child.kind() {
22 "pragma_directive" => Some(text_symbol(child, source, SymbolKind::STRING)),
23 "import_directive" => Some(import_symbol(child, source)),
24 "contract_declaration" => contract_symbol(child, source, SymbolKind::CLASS),
25 "interface_declaration" => contract_symbol(child, source, SymbolKind::INTERFACE),
26 "library_declaration" => contract_symbol(child, source, SymbolKind::NAMESPACE),
27 "struct_declaration" => struct_symbol(child, source),
28 "enum_declaration" => enum_symbol(child, source),
29 "function_definition" => function_symbol(child, source),
30 "event_definition" | "error_declaration" => id_symbol(child, source, SymbolKind::EVENT),
31 "state_variable_declaration" => id_symbol(child, source, SymbolKind::FIELD),
32 "user_defined_type_definition" => id_symbol(child, source, SymbolKind::TYPE_PARAMETER),
33 _ => None,
34 })
35 .collect()
36}
37
38fn collect_contract_members(body: Node, source: &str) -> Vec<DocumentSymbol> {
39 named_children(body)
40 .filter_map(|child| match child.kind() {
41 "function_definition" => function_symbol(child, source),
42 "constructor_definition" => Some(leaf("constructor", SymbolKind::CONSTRUCTOR, child)),
43 "fallback_receive_definition" => Some(leaf(
44 &fallback_or_receive(child, source),
45 SymbolKind::FUNCTION,
46 child,
47 )),
48 "state_variable_declaration" => id_symbol(child, source, SymbolKind::FIELD),
49 "event_definition" | "error_declaration" => id_symbol(child, source, SymbolKind::EVENT),
50 "modifier_definition" => id_symbol(child, source, SymbolKind::METHOD),
51 "struct_declaration" => struct_symbol(child, source),
52 "enum_declaration" => enum_symbol(child, source),
53 "using_directive" => Some(text_symbol(child, source, SymbolKind::PROPERTY)),
54 "user_defined_type_definition" => id_symbol(child, source, SymbolKind::TYPE_PARAMETER),
55 _ => None,
56 })
57 .collect()
58}
59
60fn contract_symbol(node: Node, source: &str, kind: SymbolKind) -> Option<DocumentSymbol> {
63 let name = child_id_text(node, source)?;
64 let children = find_child(node, "contract_body")
65 .map(|body| collect_contract_members(body, source))
66 .filter(|c| !c.is_empty());
67
68 Some(DocumentSymbol {
69 name: name.into(),
70 detail: None,
71 kind,
72 range: range(node),
73 selection_range: child_id_range(node)?,
74 children,
75 tags: None,
76 deprecated: None,
77 })
78}
79
80fn function_symbol(node: Node, source: &str) -> Option<DocumentSymbol> {
81 let name = child_id_text(node, source)?;
82 Some(DocumentSymbol {
83 name: name.into(),
84 detail: Some(function_detail(node, source)),
85 kind: SymbolKind::FUNCTION,
86 range: range(node),
87 selection_range: child_id_range(node)?,
88 children: None,
89 tags: None,
90 deprecated: None,
91 })
92}
93
94fn struct_symbol(node: Node, source: &str) -> Option<DocumentSymbol> {
95 let name = child_id_text(node, source)?;
96 let children = find_child(node, "struct_body")
97 .map(|body| {
98 named_children(body)
99 .filter(|c| c.kind() == "struct_member")
100 .filter_map(|c| id_symbol(c, source, SymbolKind::FIELD))
101 .collect::<Vec<_>>()
102 })
103 .filter(|c| !c.is_empty());
104
105 Some(DocumentSymbol {
106 name: name.into(),
107 detail: None,
108 kind: SymbolKind::STRUCT,
109 range: range(node),
110 selection_range: child_id_range(node)?,
111 children,
112 tags: None,
113 deprecated: None,
114 })
115}
116
117fn enum_symbol(node: Node, source: &str) -> Option<DocumentSymbol> {
118 let name = child_id_text(node, source)?;
119 let children = find_child(node, "enum_body")
120 .map(|body| {
121 named_children(body)
122 .filter(|c| c.kind() == "enum_value")
123 .map(|c| leaf(&source[c.byte_range()], SymbolKind::ENUM_MEMBER, c))
124 .collect::<Vec<_>>()
125 })
126 .filter(|c| !c.is_empty());
127
128 Some(DocumentSymbol {
129 name: name.into(),
130 detail: None,
131 kind: SymbolKind::ENUM,
132 range: range(node),
133 selection_range: child_id_range(node)?,
134 children,
135 tags: None,
136 deprecated: None,
137 })
138}
139
140fn id_symbol(node: Node, source: &str, kind: SymbolKind) -> Option<DocumentSymbol> {
142 let name = child_id_text(node, source)?;
143 Some(DocumentSymbol {
144 name: name.into(),
145 detail: None,
146 kind,
147 range: range(node),
148 selection_range: child_id_range(node).unwrap_or(range(node)),
149 children: None,
150 tags: None,
151 deprecated: None,
152 })
153}
154
155fn text_symbol(node: Node, source: &str, kind: SymbolKind) -> DocumentSymbol {
157 let text = source[node.byte_range()].trim_end_matches(';').trim();
158 leaf(text, kind, node)
159}
160
161fn import_symbol(node: Node, source: &str) -> DocumentSymbol {
162 let name = find_child(node, "string")
163 .map(|s| format!("import {}", &source[s.byte_range()]))
164 .unwrap_or_else(|| {
165 source[node.byte_range()]
166 .trim_end_matches(';')
167 .trim()
168 .into()
169 });
170 leaf(&name, SymbolKind::MODULE, node)
171}
172
173fn leaf(name: &str, kind: SymbolKind, node: Node) -> DocumentSymbol {
175 DocumentSymbol {
176 name: name.into(),
177 detail: None,
178 kind,
179 range: range(node),
180 selection_range: range(node),
181 children: None,
182 tags: None,
183 deprecated: None,
184 }
185}
186
187fn function_detail(node: Node, source: &str) -> String {
188 let params: Vec<&str> = named_children(node)
189 .filter(|c| c.kind() == "parameter")
190 .map(|c| source[c.byte_range()].trim())
191 .collect();
192
193 let returns: Vec<&str> = find_child(node, "return_type_definition")
194 .map(|ret| {
195 named_children(ret)
196 .filter(|c| c.kind() == "parameter")
197 .map(|c| source[c.byte_range()].trim())
198 .collect()
199 })
200 .unwrap_or_default();
201
202 let mut sig = format!("({})", params.join(", "));
203 if !returns.is_empty() {
204 sig.push_str(&format!(" returns ({})", returns.join(", ")));
205 }
206 sig
207}
208
209pub fn extract_workspace_symbols(files: &[(Url, String)]) -> Vec<SymbolInformation> {
213 let mut parser = Parser::new();
214 parser
215 .set_language(&tree_sitter_solidity::LANGUAGE.into())
216 .expect("failed to load Solidity grammar");
217
218 let mut symbols = Vec::new();
219 for (uri, source) in files {
220 if let Some(tree) = parser.parse(source, None) {
221 collect_workspace_symbols(tree.root_node(), source, uri, None, &mut symbols);
222 }
223 }
224 symbols
225}
226
227fn collect_workspace_symbols(
228 node: Node,
229 source: &str,
230 uri: &Url,
231 container: Option<&str>,
232 out: &mut Vec<SymbolInformation>,
233) {
234 for child in named_children(node) {
235 match child.kind() {
236 "contract_declaration" | "interface_declaration" | "library_declaration" => {
238 let kind = match child.kind() {
239 "interface_declaration" => SymbolKind::INTERFACE,
240 "library_declaration" => SymbolKind::NAMESPACE,
241 _ => SymbolKind::CLASS,
242 };
243 if let Some(name) = child_id_text(child, source) {
244 push_info(out, name, kind, child, uri, container);
245 if let Some(body) = find_child(child, "contract_body") {
246 collect_workspace_symbols(body, source, uri, Some(name), out);
247 }
248 }
249 }
250 "struct_declaration" => {
251 if let Some(name) = child_id_text(child, source) {
252 push_info(out, name, SymbolKind::STRUCT, child, uri, container);
253 if let Some(body) = find_child(child, "struct_body") {
254 collect_workspace_symbols(body, source, uri, Some(name), out);
255 }
256 }
257 }
258 "enum_declaration" => {
259 if let Some(name) = child_id_text(child, source) {
260 push_info(out, name, SymbolKind::ENUM, child, uri, container);
261 if let Some(body) = find_child(child, "enum_body") {
262 collect_workspace_symbols(body, source, uri, Some(name), out);
263 }
264 }
265 }
266 "function_definition" => {
268 push_id(out, child, source, SymbolKind::FUNCTION, uri, container)
269 }
270 "constructor_definition" => push_info(
271 out,
272 "constructor",
273 SymbolKind::CONSTRUCTOR,
274 child,
275 uri,
276 container,
277 ),
278 "state_variable_declaration" | "struct_member" => {
279 push_id(out, child, source, SymbolKind::FIELD, uri, container)
280 }
281 "event_definition" | "error_declaration" => {
282 push_id(out, child, source, SymbolKind::EVENT, uri, container)
283 }
284 "modifier_definition" => {
285 push_id(out, child, source, SymbolKind::METHOD, uri, container)
286 }
287 "enum_value" => push_info(
288 out,
289 &source[child.byte_range()],
290 SymbolKind::ENUM_MEMBER,
291 child,
292 uri,
293 container,
294 ),
295 "user_defined_type_definition" => push_id(
296 out,
297 child,
298 source,
299 SymbolKind::TYPE_PARAMETER,
300 uri,
301 container,
302 ),
303 _ => {}
304 }
305 }
306}
307
308fn push_id(
309 out: &mut Vec<SymbolInformation>,
310 node: Node,
311 source: &str,
312 kind: SymbolKind,
313 uri: &Url,
314 container: Option<&str>,
315) {
316 if let Some(name) = child_id_text(node, source) {
317 push_info(out, name, kind, node, uri, container);
318 }
319}
320
321fn push_info(
322 out: &mut Vec<SymbolInformation>,
323 name: &str,
324 kind: SymbolKind,
325 node: Node,
326 uri: &Url,
327 container: Option<&str>,
328) {
329 out.push(SymbolInformation {
330 name: name.into(),
331 kind,
332 tags: None,
333 deprecated: None,
334 location: Location {
335 uri: uri.clone(),
336 range: range(node),
337 },
338 container_name: container.map(Into::into),
339 });
340}
341
342fn parse(source: &str) -> Option<tree_sitter::Tree> {
345 let mut parser = Parser::new();
346 parser
347 .set_language(&tree_sitter_solidity::LANGUAGE.into())
348 .expect("failed to load Solidity grammar");
349 parser.parse(source, None)
350}
351
352fn range(node: Node) -> Range {
353 let s = node.start_position();
354 let e = node.end_position();
355 Range {
356 start: Position::new(s.row as u32, s.column as u32),
357 end: Position::new(e.row as u32, e.column as u32),
358 }
359}
360
361fn named_children(node: Node) -> impl Iterator<Item = Node> {
362 let mut cursor = node.walk();
363 let children: Vec<Node> = node
364 .children(&mut cursor)
365 .filter(|c| c.is_named())
366 .collect();
367 children.into_iter()
368}
369
370fn child_id_text<'a>(node: Node<'a>, source: &'a str) -> Option<&'a str> {
371 let mut cursor = node.walk();
372 node.children(&mut cursor)
373 .find(|c| c.kind() == "identifier" && c.is_named())
374 .map(|c| &source[c.byte_range()])
375}
376
377fn child_id_range(node: Node) -> Option<Range> {
378 let mut cursor = node.walk();
379 node.children(&mut cursor)
380 .find(|c| c.kind() == "identifier" && c.is_named())
381 .map(|c| range(c))
382}
383
384fn find_child<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
385 let mut cursor = node.walk();
386 node.children(&mut cursor).find(|c| c.kind() == kind)
387}
388
389fn fallback_or_receive(node: Node, source: &str) -> String {
390 let mut cursor = node.walk();
391 node.children(&mut cursor)
392 .find(|c| !c.is_named() && matches!(&source[c.byte_range()], "fallback" | "receive"))
393 .map(|c| source[c.byte_range()].into())
394 .unwrap_or_else(|| "fallback".into())
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `symbols` holds an entry with exactly this `name` and `kind`.
    fn has_symbol(symbols: &[DocumentSymbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    /// Children of the first child of `parent` whose kind is `kind`.
    ///
    /// Panics (failing the test) if `parent` has no children, has no child of that kind, or
    /// that child has no children of its own.
    fn grandchildren(parent: &DocumentSymbol, kind: SymbolKind) -> &[DocumentSymbol] {
        let nested = parent
            .children
            .as_ref()
            .unwrap()
            .iter()
            .find(|c| c.kind == kind)
            .unwrap();
        nested.children.as_ref().unwrap()
    }

    #[test]
    fn test_empty_source() {
        let symbols = extract_document_symbols("");
        assert!(symbols.is_empty());
    }

    #[test]
    fn test_simple_contract() {
        let source = r#"
pragma solidity ^0.8.0;

contract Counter {
    uint256 public count;
    function increment() public { count += 1; }
    function getCount() public view returns (uint256) { return count; }
}
"#;
        let symbols = extract_document_symbols(source);
        assert!(symbols.len() >= 2);

        let contract = symbols
            .iter()
            .find(|s| s.kind == SymbolKind::CLASS)
            .unwrap();
        assert_eq!(contract.name, "Counter");

        let children = contract.children.as_ref().unwrap();
        assert!(has_symbol(children, "count", SymbolKind::FIELD));
        assert!(has_symbol(children, "increment", SymbolKind::FUNCTION));
        assert!(has_symbol(children, "getCount", SymbolKind::FUNCTION));
    }

    #[test]
    fn test_struct_with_members() {
        let symbols =
            extract_document_symbols("contract Foo { struct Info { string name; uint256 value; } }");
        let members = grandchildren(&symbols[0], SymbolKind::STRUCT);
        assert_eq!(members.len(), 2);
        for field in ["name", "value"] {
            assert!(members.iter().any(|m| m.name == field));
        }
    }

    #[test]
    fn test_enum_with_values() {
        let symbols =
            extract_document_symbols("contract Foo { enum Status { Active, Paused, Stopped } }");
        let members = grandchildren(&symbols[0], SymbolKind::ENUM);
        assert_eq!(members.len(), 3);
        for value in ["Active", "Paused", "Stopped"] {
            assert!(members.iter().any(|m| m.name == value));
        }
    }

    #[test]
    fn test_all_member_types() {
        let source = r#"
contract Token {
    event Transfer(address from, address to, uint256 value);
    error Unauthorized();
    uint256 public totalSupply;
    modifier onlyOwner() { _; }
    constructor() {}
    function transfer(address to, uint256 amount) external returns (bool) { return true; }
    fallback() external payable {}
    receive() external payable {}
    type Price is uint256;
}
"#;
        let children = extract_document_symbols(source)
            .into_iter()
            .find(|s| s.kind == SymbolKind::CLASS)
            .unwrap()
            .children
            .unwrap();

        let expected = [
            ("Transfer", SymbolKind::EVENT),
            ("Unauthorized", SymbolKind::EVENT),
            ("totalSupply", SymbolKind::FIELD),
            ("onlyOwner", SymbolKind::METHOD),
            ("constructor", SymbolKind::CONSTRUCTOR),
            ("transfer", SymbolKind::FUNCTION),
            ("fallback", SymbolKind::FUNCTION),
            ("receive", SymbolKind::FUNCTION),
            ("Price", SymbolKind::TYPE_PARAMETER),
        ];
        for (name, kind) in expected {
            assert!(has_symbol(&children, name, kind), "missing member {name}");
        }
    }

    #[test]
    fn test_interface_and_library() {
        let source = r#"
interface IToken { function transfer(address to, uint256 amount) external returns (bool); }
library SafeMath { function add(uint256 a, uint256 b) internal pure returns (uint256) { return a + b; } }
"#;
        let symbols = extract_document_symbols(source);
        assert!(has_symbol(&symbols, "IToken", SymbolKind::INTERFACE));
        assert!(has_symbol(&symbols, "SafeMath", SymbolKind::NAMESPACE));
    }

    #[test]
    fn test_workspace_symbols() {
        let uri = Url::parse("file:///test.sol").unwrap();
        let source = "contract Foo { uint256 public bar; function baz() public {} }";
        let symbols = extract_workspace_symbols(&[(uri, source.into())]);

        assert!(symbols
            .iter()
            .any(|s| s.name == "Foo" && s.kind == SymbolKind::CLASS));
        for member in ["bar", "baz"] {
            assert!(symbols
                .iter()
                .any(|s| s.name == member && s.container_name.as_deref() == Some("Foo")));
        }
    }

    #[test]
    fn test_counter_sol() {
        let source = std::fs::read_to_string("example/Counter.sol").unwrap();
        let symbols = extract_document_symbols(&source);
        let children = symbols
            .iter()
            .find(|s| s.kind == SymbolKind::CLASS)
            .unwrap()
            .children
            .as_ref()
            .unwrap();
        for name in ["increment", "decrement", "reset", "getCount"] {
            assert!(children.iter().any(|c| c.name == name));
        }
    }

    #[test]
    fn test_function_detail() {
        let source = "contract Foo { function bar(uint256 x, address y) public pure returns (bool) { return true; } }";
        let symbols = extract_document_symbols(source);
        let detail = symbols[0]
            .children
            .as_ref()
            .unwrap()
            .iter()
            .find(|c| c.name == "bar")
            .unwrap()
            .detail
            .as_deref()
            .unwrap();
        for needle in ["uint256 x", "address y", "returns"] {
            assert!(detail.contains(needle));
        }
    }
}