Skip to main content

knowdit_sol/
storage.rs

//! Tree-sitter extraction of state variables and inheritance specifiers, mirroring
//! `cg::extract_repo_contracts_functions` but yielding storage-shaped records rather than
//! callgraph-shaped ones. The output is project-local: contract ids match
//! [`crate::cg::ExtractedContract::id`] and inheritance is resolved by *name within the same
//! project*. Parent names that don't resolve to a project-local contract (3rd-party imports,
//! etc.) are still recorded, with `parent_contract_id: None`.
6
7use std::{
8    collections::{BTreeMap, HashMap},
9    path::PathBuf,
10};
11
12use color_eyre::eyre::{ContextCompat, WrapErr, ensure};
13use knowdit_repo_model::cg::FileChunk;
14use serde::Serialize;
15use tree_sitter::{Node, Parser};
16
17use crate::cg::{ExtractedContract, SolidityContractKind};
18
/// One state variable declared inside a tree-sitter-extracted contract.
///
/// Produced by [`extract_state_variables_and_inheritance`]. `constant` and `immutable`
/// declarations are included and flagged rather than filtered out.
#[derive(Debug, Clone, Serialize)]
pub struct ExtractedStateVariable {
    /// Sequential id assigned by the extractor, starting at 1 within one extraction run.
    pub id: i32,
    /// Project-local id of the declaring contract (matches `ExtractedContract.id`).
    pub contract_id: i32,
    /// Declaring contract name (denormalized for readability).
    pub contract_name: String,
    /// Variable name as written in source.
    pub name: String,
    /// Stringified Solidity type as it appears in source (mappings/structs included verbatim),
    /// trimmed; empty if the parser produced no `type_name` for the declaration.
    pub type_text: String,
    /// True for `constant` declarations (detected by a whole-word keyword scan of the
    /// declaration source); we still record them but the agent should skip them as
    /// "no storage" rows.
    pub is_constant: bool,
    /// True for `immutable` declarations; immutables are stored in code, not storage, but the
    /// agent typically treats them as state variables for read/write purposes.
    pub is_immutable: bool,
    /// Path of the declaring `.sol` file, relative to the repository root.
    pub relative_file_path: PathBuf,
    /// Source chunk built by `crate::node::node_chunk_pub` from the
    /// `state_variable_declaration` node.
    pub chunk: FileChunk,
}
39
/// One direct `is`-spec edge captured per contract. `parent_name` is the textual reference;
/// `parent_contract_id` is `None` if it doesn't resolve to any project-local contract.
///
/// Resolution is by name only: when several project-local contracts share `parent_name`, one
/// row is emitted per candidate.
#[derive(Debug, Clone, Serialize)]
pub struct ExtractedInheritance {
    /// Project-local id of the inheriting (child) contract.
    pub contract_id: i32,
    /// Name of the inheriting contract (denormalized for readability).
    pub contract_name: String,
    /// Parent name from the `is` list; for a dotted reference like `Lib.Base` this is the
    /// last segment (`Base`).
    pub parent_name: String,
    /// Project-local id the parent name resolved to, or `None` when unresolved.
    pub parent_contract_id: Option<i32>,
}
49
/// Bundled output of the tree-sitter storage extractor
/// ([`extract_state_variables_and_inheritance`]).
#[derive(Debug, Clone, Serialize, Default)]
pub struct StorageExtractionResult {
    /// Every state variable found, in the order the extractor visits them (files sorted by
    /// relative path, then declaration order within each contract body).
    pub state_variables: Vec<ExtractedStateVariable>,
    /// Direct inheritance edges collected from each contract's `is` list.
    pub inherits: Vec<ExtractedInheritance>,
}
56
57/// Walk the same .sol files used by the call-graph extractor and pull out state variables and
58/// inheritance specifiers. The caller is expected to have already parsed/extracted contracts
59/// via [`crate::cg::extract_repo_contracts_functions`]; we reuse the (id, name) mapping so the
60/// project-local ids stay consistent across the two passes.
61pub async fn extract_state_variables_and_inheritance(
62    repo_root: &std::path::Path,
63    contracts: &[ExtractedContract],
64) -> Result<StorageExtractionResult, color_eyre::Report> {
65    let mut parser = Parser::new();
66    let language = tree_sitter_solidity::LANGUAGE.into();
67    parser
68        .set_language(&language)
69        .wrap_err("failed to load tree-sitter Solidity grammar")?;
70
71    // Group contract ids by their (relative_file_path, name) so we can re-attach state vars
72    // from a fresh tree-sitter pass. Names should be unique within a single .sol file in
73    // well-formed projects; if there's a clash we take the first-seen contract id.
74    let mut id_by_file_and_name: HashMap<(PathBuf, String), i32> = HashMap::new();
75    let mut name_to_ids: BTreeMap<String, Vec<i32>> = BTreeMap::new();
76    for c in contracts {
77        id_by_file_and_name
78            .entry((c.relative_file_path.clone(), c.name.clone()))
79            .or_insert(c.id);
80        name_to_ids.entry(c.name.clone()).or_default().push(c.id);
81    }
82    // De-dup file lists so we parse each file once.
83    let mut files_to_parse: Vec<PathBuf> = contracts
84        .iter()
85        .map(|c| c.relative_file_path.clone())
86        .collect();
87    files_to_parse.sort();
88    files_to_parse.dedup();
89
90    let mut state_variables = Vec::new();
91    let mut inherits = Vec::new();
92    let mut next_state_var_id = 1i32;
93
94    for relative in &files_to_parse {
95        let absolute = repo_root.join(relative);
96        let source = tokio::fs::read_to_string(&absolute)
97            .await
98            .wrap_err_with(|| format!("failed to read Solidity file {}", absolute.display()))?;
99        let tree = parser
100            .parse(source.as_bytes(), None)
101            .wrap_err_with(|| format!("failed to parse Solidity file {}", absolute.display()))?;
102        let root = tree.root_node();
103
104        let mut contract_nodes = Vec::new();
105        crate::node::collect_contract_nodes_pub(root, &mut contract_nodes);
106
107        for cn in contract_nodes {
108            if SolidityContractKind::from_node_kind(cn.kind()).is_none() {
109                continue;
110            }
111            let Some(name_node) = cn.child_by_field_name("name") else {
112                continue;
113            };
114            let contract_name = name_node
115                .utf8_text(source.as_bytes())
116                .ok()
117                .map(|s| s.to_string())
118                .unwrap_or_default();
119            let Some(&contract_id) =
120                id_by_file_and_name.get(&(relative.clone(), contract_name.clone()))
121            else {
122                continue;
123            };
124
125            // Inheritance specifiers
126            let mut cursor = cn.walk();
127            for child in cn.named_children(&mut cursor) {
128                if child.kind() == "inheritance_specifier" {
129                    let parent_name = parent_name_of_inheritance_specifier(child, &source);
130                    let Some(parent_name) = parent_name else {
131                        continue;
132                    };
133                    // Resolve by name. If multiple contracts share the name, keep them all (one
134                    // row per resolution); typically there's one.
135                    let parent_ids = name_to_ids.get(&parent_name).cloned().unwrap_or_default();
136                    if parent_ids.is_empty() {
137                        inherits.push(ExtractedInheritance {
138                            contract_id,
139                            contract_name: contract_name.clone(),
140                            parent_name,
141                            parent_contract_id: None,
142                        });
143                    } else {
144                        for pid in parent_ids {
145                            if pid == contract_id {
146                                continue;
147                            }
148                            inherits.push(ExtractedInheritance {
149                                contract_id,
150                                contract_name: contract_name.clone(),
151                                parent_name: parent_name.clone(),
152                                parent_contract_id: Some(pid),
153                            });
154                        }
155                    }
156                }
157            }
158
159            // State variables
160            if let Some(body) = cn.child_by_field_name("body") {
161                let mut cursor = body.walk();
162                for child in body.named_children(&mut cursor) {
163                    if child.kind() != "state_variable_declaration" {
164                        continue;
165                    }
166                    let Some(state_var) = parse_state_variable(
167                        child,
168                        &source,
169                        next_state_var_id,
170                        contract_id,
171                        contract_name.clone(),
172                        relative.clone(),
173                    )?
174                    else {
175                        continue;
176                    };
177                    next_state_var_id += 1;
178                    state_variables.push(state_var);
179                }
180            }
181        }
182    }
183
184    Ok(StorageExtractionResult {
185        state_variables,
186        inherits,
187    })
188}
189
190fn parent_name_of_inheritance_specifier(node: Node<'_>, source: &str) -> Option<String> {
191    let mut cursor = node.walk();
192    for child in node.named_children(&mut cursor) {
193        if child.kind() == "user_defined_type" {
194            let mut cur2 = child.walk();
195            // Take the last identifier — for `Foo.Bar` it's the inner name; if a single
196            // identifier, that's the parent name.
197            let mut last_id: Option<Node<'_>> = None;
198            for c in child.named_children(&mut cur2) {
199                if c.kind() == "identifier" {
200                    last_id = Some(c);
201                }
202            }
203            if let Some(id_node) = last_id
204                && let Ok(text) = id_node.utf8_text(source.as_bytes())
205            {
206                return Some(text.to_string());
207            }
208        } else if child.kind() == "identifier"
209            && let Ok(text) = child.utf8_text(source.as_bytes())
210        {
211            return Some(text.to_string());
212        }
213    }
214    None
215}
216
217fn parse_state_variable(
218    node: Node<'_>,
219    source: &str,
220    id: i32,
221    contract_id: i32,
222    contract_name: String,
223    relative_file_path: PathBuf,
224) -> Result<Option<ExtractedStateVariable>, color_eyre::Report> {
225    let raw = node
226        .utf8_text(source.as_bytes())
227        .wrap_err("state_variable_declaration is not utf-8")?;
228
229    // Name: last named identifier child (after type_name). Skip identifiers nested inside
230    // type_name (e.g. user-defined struct name).
231    let mut name: Option<String> = None;
232    let mut type_text: Option<String> = None;
233    let mut cursor = node.walk();
234    for child in node.named_children(&mut cursor) {
235        match child.kind() {
236            "type_name" => {
237                let text = child
238                    .utf8_text(source.as_bytes())
239                    .wrap_err("type_name not utf-8")?;
240                type_text = Some(text.trim().to_string());
241            }
242            "identifier" => {
243                let text = child
244                    .utf8_text(source.as_bytes())
245                    .wrap_err("identifier not utf-8")?;
246                name = Some(text.to_string());
247            }
248            _ => {}
249        }
250    }
251    let Some(name) = name else {
252        return Ok(None);
253    };
254    let type_text = type_text.unwrap_or_default();
255
256    let is_constant = contains_keyword(raw, "constant");
257    let is_immutable = contains_keyword(raw, "immutable");
258
259    ensure!(
260        source.is_char_boundary(node.start_byte()) && source.is_char_boundary(node.end_byte()),
261        "tree-sitter produced non-UTF-8-boundary state variable range {}..{}",
262        node.start_byte(),
263        node.end_byte()
264    );
265    let chunk = crate::node::node_chunk_pub(node, source)?;
266
267    Ok(Some(ExtractedStateVariable {
268        id,
269        contract_id,
270        contract_name,
271        name,
272        type_text,
273        is_constant,
274        is_immutable,
275        relative_file_path,
276        chunk,
277    }))
278}
279
/// Returns `true` if `kw` occurs in `haystack` as a whole word, i.e. not embedded in a longer
/// Solidity identifier. For example, `constant` matches in `uint constant X`, but not in
/// `constantValue`, `$constant`, or `constant$`.
///
/// An empty `kw` never matches.
fn contains_keyword(haystack: &str, kw: &str) -> bool {
    // An empty needle matches at every position without advancing `idx`, which would loop
    // forever whenever the first "match" is rejected.
    if kw.is_empty() {
        return false;
    }
    let bytes = haystack.as_bytes();
    let mut idx = 0usize;
    while let Some(found) = haystack[idx..].find(kw) {
        let pos = idx + found;
        let end = pos + kw.len();
        let lhs_ok = pos
            .checked_sub(1)
            .and_then(|p| bytes.get(p))
            .is_none_or(|&b| !is_id_byte(b));
        let rhs_ok = bytes.get(end).is_none_or(|&b| !is_id_byte(b));
        if lhs_ok && rhs_ok {
            return true;
        }
        // `end` is a char boundary: `find` returns one and `kw` is itself valid UTF-8.
        idx = end;
    }
    false
}

/// Whether `b` can appear inside a Solidity identifier (`[a-zA-Z0-9_$]`).
fn is_id_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$'
}