Skip to main content

knowdit_sol/
header.rs

1//! "Header" extraction for a Solidity source file: every byte before
2//! the first contract / interface / library / abstract-contract
3//! declaration. By construction this captures the leading SPDX
4//! comment, `pragma` line(s), every `import` statement, and any blank
5//! lines or comments interleaved with them — exactly the slice the
6//! harness agent needs to imitate when authoring a new test file.
7//!
8//! Implemented via tree-sitter (not string scanning) so quirks like
9//! block comments, multi-line `import { ... }`, or pragma versions
10//! split across lines don't trip the cutoff.
11
12use color_eyre::eyre::{Result, WrapErr, eyre};
13use tree_sitter::Parser;
14
15/// Return every byte in `source` up to (but not including) the first
16/// contract-shaped top-level declaration. Returns the whole source
17/// when none is found (e.g. an interface-only file).
18pub fn extract_header(source: &str) -> Result<String> {
19    let mut parser = Parser::new();
20    let language = tree_sitter_solidity::LANGUAGE.into();
21    parser
22        .set_language(&language)
23        .wrap_err("failed to load tree-sitter Solidity grammar")?;
24    let tree = parser
25        .parse(source, None)
26        .ok_or_else(|| eyre!("tree-sitter returned no parse tree for the input source"))?;
27    let root = tree.root_node();
28    let mut cutoff = source.len();
29    for i in 0..root.named_child_count() {
30        // `named_child` only returns Some for indices < named_child_count
31        // by construction, so the unwrap below cannot trip.
32        let child = root.named_child(i).expect("i < named_child_count");
33        if matches!(
34            child.kind(),
35            "contract_declaration"
36                | "interface_declaration"
37                | "library_declaration"
38                | "abstract_contract_declaration"
39        ) {
40            cutoff = child.start_byte();
41            break;
42        }
43    }
44    // Clamp to a UTF-8 boundary in case tree-sitter handed back a byte
45    // offset mid-codepoint (it shouldn't, but `to_string()` panics if
46    // we slice mid-char).
47    while cutoff > 0 && !source.is_char_boundary(cutoff) {
48        cutoff -= 1;
49    }
50    Ok(source[..cutoff].to_string())
51}
52
#[cfg(test)]
mod tests {
    use super::extract_header;

    /// A typical test file: the header must hold the SPDX line, the
    /// pragma and both imports, and must stop before the contract.
    #[test]
    fn picks_up_spdx_pragma_and_imports() {
        let source = concat!(
            "// SPDX-License-Identifier: UNLICENSED\n",
            "pragma solidity ^0.8.24;\n",
            "import {Foo} from \"foo/Foo.sol\";\n",
            "import \"forge-std/Test.sol\";\n",
            "\n",
            "contract Test_42 is Test {\n",
            "function setUp() public {}\n",
            "}\n",
        );
        let extracted = extract_header(source).unwrap();

        let expected_fragments = [
            "// SPDX-License-Identifier",
            "pragma solidity",
            "import {Foo}",
            "forge-std/Test.sol",
        ];
        for fragment in expected_fragments {
            assert!(extracted.contains(fragment));
        }
        assert!(!extracted.contains("contract Test_42"));
    }

    /// With no contract-shaped declaration the input comes back
    /// unchanged.
    #[test]
    fn keeps_full_source_when_no_contract_declaration() {
        let source = concat!("pragma solidity ^0.8.0;\n", "import \"a.sol\";\n");
        let extracted = extract_header(source).unwrap();
        assert_eq!(extracted, source);
    }

    /// An abstract contract ends the header just like a plain one, so
    /// neither it nor anything after it may appear in the result.
    #[test]
    fn cuts_at_abstract_contract() {
        let source = concat!(
            "pragma solidity ^0.8.0;\n",
            "import \"a.sol\";\n",
            "abstract contract Foo {}\n",
            "contract Bar {}\n",
        );
        let extracted = extract_header(source).unwrap();

        assert!(extracted.contains("import \"a.sol\""));
        for forbidden in ["abstract contract", "contract Bar"] {
            assert!(!extracted.contains(forbidden));
        }
    }
}