1use color_eyre::eyre::{Result, WrapErr, eyre};
13use tree_sitter::Parser;
14
15pub fn extract_header(source: &str) -> Result<String> {
19 let mut parser = Parser::new();
20 let language = tree_sitter_solidity::LANGUAGE.into();
21 parser
22 .set_language(&language)
23 .wrap_err("failed to load tree-sitter Solidity grammar")?;
24 let tree = parser
25 .parse(source, None)
26 .ok_or_else(|| eyre!("tree-sitter returned no parse tree for the input source"))?;
27 let root = tree.root_node();
28 let mut cutoff = source.len();
29 for i in 0..root.named_child_count() {
30 let child = root.named_child(i).expect("i < named_child_count");
33 if matches!(
34 child.kind(),
35 "contract_declaration"
36 | "interface_declaration"
37 | "library_declaration"
38 | "abstract_contract_declaration"
39 ) {
40 cutoff = child.start_byte();
41 break;
42 }
43 }
44 while cutoff > 0 && !source.is_char_boundary(cutoff) {
48 cutoff -= 1;
49 }
50 Ok(source[..cutoff].to_string())
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn picks_up_spdx_pragma_and_imports() {
        let src = "// SPDX-License-Identifier: UNLICENSED\n\
                   pragma solidity ^0.8.24;\n\
                   import {Foo} from \"foo/Foo.sol\";\n\
                   import \"forge-std/Test.sol\";\n\
                   \n\
                   contract Test_42 is Test {\n\
                   function setUp() public {}\n\
                   }\n";
        let header = extract_header(src).unwrap();
        assert!(header.contains("// SPDX-License-Identifier"));
        assert!(header.contains("pragma solidity"));
        assert!(header.contains("import {Foo}"));
        assert!(header.contains("forge-std/Test.sol"));
        assert!(!header.contains("contract Test_42"));
    }

    #[test]
    fn keeps_full_source_when_no_contract_declaration() {
        let src = "pragma solidity ^0.8.0;\nimport \"a.sol\";\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, src);
    }

    #[test]
    fn cuts_at_abstract_contract() {
        let src = "pragma solidity ^0.8.0;\nimport \"a.sol\";\nabstract contract Foo {}\ncontract Bar {}\n";
        let header = extract_header(src).unwrap();
        assert!(header.contains("import \"a.sol\""));
        assert!(!header.contains("abstract contract"));
        assert!(!header.contains("contract Bar"));
    }

    // The interface and library arms of the declaration match were previously
    // untested; pin that each one ends the header exactly at its first byte.
    #[test]
    fn cuts_at_interface_declaration() {
        let src = "pragma solidity ^0.8.0;\ninterface IFoo {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, "pragma solidity ^0.8.0;\n");
    }

    #[test]
    fn cuts_at_library_declaration() {
        let src = "pragma solidity ^0.8.0;\nlibrary MathLib {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, "pragma solidity ^0.8.0;\n");
    }

    #[test]
    fn empty_source_yields_empty_header() {
        assert_eq!(extract_header("").unwrap(), "");
    }

    // Multi-byte characters before the cutoff must survive the byte-offset slice.
    #[test]
    fn preserves_non_ascii_text_before_cutoff() {
        let src = "// h\u{e9}ader \u{fc}nicode\ncontract A {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, "// h\u{e9}ader \u{fc}nicode\n");
    }
}