use color_eyre::eyre::{Result, WrapErr, eyre};
use tree_sitter::Parser;
/// Returns the file-level preamble of a Solidity source file.
///
/// The preamble is everything that precedes the first top-level `contract`, `interface`
/// or `library` declaration. Typically that is the SPDX comment, `pragma` directives,
/// `import`s, and any other top-level items written before that declaration.
///
/// The source is parsed with the tree-sitter Solidity grammar. The root node's named
/// children are scanned in source order, and the text before the first matching
/// declaration is returned as an owned `String`. If no such declaration is found among
/// the root's named children, the entire source is returned unchanged.
///
/// # Errors
///
/// Returns an error if the Solidity grammar cannot be loaded into the parser, or if
/// tree-sitter yields no parse tree for `source`.
pub fn extract_header(source: &str) -> Result<String> {
    // Node kinds that mark the start of the contract-level part of the file.
    // NOTE(review): recent tree-sitter-solidity grammars model `abstract contract` as a
    // `contract_declaration`; the `abstract_contract_declaration` entry is kept in case
    // another grammar version emits it — confirm against the pinned grammar.
    const TYPE_DECLARATION_KINDS: [&str; 4] = [
        "contract_declaration",
        "interface_declaration",
        "library_declaration",
        "abstract_contract_declaration",
    ];

    let grammar = tree_sitter_solidity::LANGUAGE.into();
    let mut parser = Parser::new();
    parser
        .set_language(&grammar)
        .wrap_err("failed to load tree-sitter Solidity grammar")?;

    let Some(tree) = parser.parse(source, None) else {
        return Err(eyre!("tree-sitter returned no parse tree for the input source"));
    };

    let root = tree.root_node();
    let mut cursor = root.walk();
    let first_type_start = root
        .named_children(&mut cursor)
        .find(|node| TYPE_DECLARATION_KINDS.contains(&node.kind()))
        .map(|node| node.start_byte());
    let end = first_type_start.unwrap_or(source.len());

    // Defensive: round the cut point down to a UTF-8 character boundary so the slice
    // below can never panic. Offset 0 is always a boundary, so the fallback is unused.
    let end = (0..=end)
        .rev()
        .find(|&idx| source.is_char_boundary(idx))
        .unwrap_or(0);

    Ok(source[..end].to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the part of `src` that ends immediately before the first `marker`.
    ///
    /// Panics if `marker` does not occur, since that indicates a broken test fixture.
    fn prefix_before<'a>(src: &'a str, marker: &str) -> &'a str {
        let idx = src
            .find(marker)
            .unwrap_or_else(|| panic!("test fixture is missing marker {marker:?}"));
        &src[..idx]
    }

    #[test]
    fn picks_up_spdx_pragma_and_imports() {
        let src = "// SPDX-License-Identifier: UNLICENSED\n\
                   pragma solidity ^0.8.24;\n\
                   import {Foo} from \"foo/Foo.sol\";\n\
                   import \"forge-std/Test.sol\";\n\
                   \n\
                   contract Test_42 is Test {\n\
                   function setUp() public {}\n\
                   }\n";
        let header = extract_header(src).unwrap();
        assert!(header.contains("// SPDX-License-Identifier"));
        assert!(header.contains("pragma solidity"));
        assert!(header.contains("import {Foo}"));
        assert!(header.contains("forge-std/Test.sol"));
        assert!(!header.contains("contract Test_42"));
        // The cut lands exactly on the `contract` keyword; trailing blank lines are kept.
        assert_eq!(header, prefix_before(src, "contract Test_42"));
    }

    #[test]
    fn keeps_full_source_when_no_contract_declaration() {
        let src = "pragma solidity ^0.8.0;\nimport \"a.sol\";\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, src);
    }

    #[test]
    fn cuts_at_abstract_contract() {
        let src = "pragma solidity ^0.8.0;\nimport \"a.sol\";\nabstract contract Foo {}\ncontract Bar {}\n";
        let header = extract_header(src).unwrap();
        assert!(header.contains("import \"a.sol\""));
        assert!(!header.contains("abstract contract"));
        assert!(!header.contains("contract Bar"));
        assert_eq!(header, prefix_before(src, "abstract contract Foo"));
    }

    #[test]
    fn cuts_at_interface() {
        let src = "pragma solidity ^0.8.0;\nimport \"a.sol\";\ninterface IFoo {}\ncontract Foo is IFoo {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, prefix_before(src, "interface IFoo"));
        assert!(!header.contains("contract Foo"));
    }

    #[test]
    fn cuts_at_library() {
        let src = "pragma solidity ^0.8.0;\nlibrary Math {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, "pragma solidity ^0.8.0;\n");
    }

    #[test]
    fn empty_source_yields_empty_header() {
        assert_eq!(extract_header("").unwrap(), "");
    }

    #[test]
    fn contract_at_start_yields_empty_header() {
        assert_eq!(extract_header("contract A {}\n").unwrap(), "");
    }

    #[test]
    fn contract_text_inside_comment_does_not_cut() {
        // The cut is driven by the parse tree, so declaration-like text in a comment
        // must stay in the header.
        let src = "/* contract Decoy {} */\npragma solidity ^0.8.0;\ncontract Real {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, prefix_before(src, "contract Real"));
        assert!(header.contains("contract Decoy"));
    }

    #[test]
    fn preserves_non_ascii_header_text() {
        // Multi-byte UTF-8 before the cut point must not break byte-offset slicing.
        let src = "// Autor: Jürgen — ©\npragma solidity ^0.8.0;\ncontract A {}\n";
        let header = extract_header(src).unwrap();
        assert_eq!(header, prefix_before(src, "contract A"));
    }
}