1use std::{
8 collections::{BTreeMap, HashMap},
9 path::PathBuf,
10};
11
12use color_eyre::eyre::{ContextCompat, WrapErr, ensure};
13use knowdit_repo_model::cg::FileChunk;
14use serde::Serialize;
15use tree_sitter::{Node, Parser};
16
17use crate::cg::{ExtractedContract, SolidityContractKind};
18
19#[derive(Debug, Clone, Serialize)]
21pub struct ExtractedStateVariable {
22 pub id: i32,
23 pub contract_id: i32,
25 pub contract_name: String,
27 pub name: String,
28 pub type_text: String,
30 pub is_constant: bool,
33 pub is_immutable: bool,
36 pub relative_file_path: PathBuf,
37 pub chunk: FileChunk,
38}
39
40#[derive(Debug, Clone, Serialize)]
43pub struct ExtractedInheritance {
44 pub contract_id: i32,
45 pub contract_name: String,
46 pub parent_name: String,
47 pub parent_contract_id: Option<i32>,
48}
49
50#[derive(Debug, Clone, Serialize, Default)]
52pub struct StorageExtractionResult {
53 pub state_variables: Vec<ExtractedStateVariable>,
54 pub inherits: Vec<ExtractedInheritance>,
55}
56
57pub async fn extract_state_variables_and_inheritance(
62 repo_root: &std::path::Path,
63 contracts: &[ExtractedContract],
64) -> Result<StorageExtractionResult, color_eyre::Report> {
65 let mut parser = Parser::new();
66 let language = tree_sitter_solidity::LANGUAGE.into();
67 parser
68 .set_language(&language)
69 .wrap_err("failed to load tree-sitter Solidity grammar")?;
70
71 let mut id_by_file_and_name: HashMap<(PathBuf, String), i32> = HashMap::new();
75 let mut name_to_ids: BTreeMap<String, Vec<i32>> = BTreeMap::new();
76 for c in contracts {
77 id_by_file_and_name
78 .entry((c.relative_file_path.clone(), c.name.clone()))
79 .or_insert(c.id);
80 name_to_ids.entry(c.name.clone()).or_default().push(c.id);
81 }
82 let mut files_to_parse: Vec<PathBuf> = contracts
84 .iter()
85 .map(|c| c.relative_file_path.clone())
86 .collect();
87 files_to_parse.sort();
88 files_to_parse.dedup();
89
90 let mut state_variables = Vec::new();
91 let mut inherits = Vec::new();
92 let mut next_state_var_id = 1i32;
93
94 for relative in &files_to_parse {
95 let absolute = repo_root.join(relative);
96 let source = tokio::fs::read_to_string(&absolute)
97 .await
98 .wrap_err_with(|| format!("failed to read Solidity file {}", absolute.display()))?;
99 let tree = parser
100 .parse(source.as_bytes(), None)
101 .wrap_err_with(|| format!("failed to parse Solidity file {}", absolute.display()))?;
102 let root = tree.root_node();
103
104 let mut contract_nodes = Vec::new();
105 crate::node::collect_contract_nodes_pub(root, &mut contract_nodes);
106
107 for cn in contract_nodes {
108 if SolidityContractKind::from_node_kind(cn.kind()).is_none() {
109 continue;
110 }
111 let Some(name_node) = cn.child_by_field_name("name") else {
112 continue;
113 };
114 let contract_name = name_node
115 .utf8_text(source.as_bytes())
116 .ok()
117 .map(|s| s.to_string())
118 .unwrap_or_default();
119 let Some(&contract_id) =
120 id_by_file_and_name.get(&(relative.clone(), contract_name.clone()))
121 else {
122 continue;
123 };
124
125 let mut cursor = cn.walk();
127 for child in cn.named_children(&mut cursor) {
128 if child.kind() == "inheritance_specifier" {
129 let parent_name = parent_name_of_inheritance_specifier(child, &source);
130 let Some(parent_name) = parent_name else {
131 continue;
132 };
133 let parent_ids = name_to_ids.get(&parent_name).cloned().unwrap_or_default();
136 if parent_ids.is_empty() {
137 inherits.push(ExtractedInheritance {
138 contract_id,
139 contract_name: contract_name.clone(),
140 parent_name,
141 parent_contract_id: None,
142 });
143 } else {
144 for pid in parent_ids {
145 if pid == contract_id {
146 continue;
147 }
148 inherits.push(ExtractedInheritance {
149 contract_id,
150 contract_name: contract_name.clone(),
151 parent_name: parent_name.clone(),
152 parent_contract_id: Some(pid),
153 });
154 }
155 }
156 }
157 }
158
159 if let Some(body) = cn.child_by_field_name("body") {
161 let mut cursor = body.walk();
162 for child in body.named_children(&mut cursor) {
163 if child.kind() != "state_variable_declaration" {
164 continue;
165 }
166 let Some(state_var) = parse_state_variable(
167 child,
168 &source,
169 next_state_var_id,
170 contract_id,
171 contract_name.clone(),
172 relative.clone(),
173 )?
174 else {
175 continue;
176 };
177 next_state_var_id += 1;
178 state_variables.push(state_var);
179 }
180 }
181 }
182 }
183
184 Ok(StorageExtractionResult {
185 state_variables,
186 inherits,
187 })
188}
189
190fn parent_name_of_inheritance_specifier(node: Node<'_>, source: &str) -> Option<String> {
191 let mut cursor = node.walk();
192 for child in node.named_children(&mut cursor) {
193 if child.kind() == "user_defined_type" {
194 let mut cur2 = child.walk();
195 let mut last_id: Option<Node<'_>> = None;
198 for c in child.named_children(&mut cur2) {
199 if c.kind() == "identifier" {
200 last_id = Some(c);
201 }
202 }
203 if let Some(id_node) = last_id
204 && let Ok(text) = id_node.utf8_text(source.as_bytes())
205 {
206 return Some(text.to_string());
207 }
208 } else if child.kind() == "identifier"
209 && let Ok(text) = child.utf8_text(source.as_bytes())
210 {
211 return Some(text.to_string());
212 }
213 }
214 None
215}
216
217fn parse_state_variable(
218 node: Node<'_>,
219 source: &str,
220 id: i32,
221 contract_id: i32,
222 contract_name: String,
223 relative_file_path: PathBuf,
224) -> Result<Option<ExtractedStateVariable>, color_eyre::Report> {
225 let raw = node
226 .utf8_text(source.as_bytes())
227 .wrap_err("state_variable_declaration is not utf-8")?;
228
229 let mut name: Option<String> = None;
232 let mut type_text: Option<String> = None;
233 let mut cursor = node.walk();
234 for child in node.named_children(&mut cursor) {
235 match child.kind() {
236 "type_name" => {
237 let text = child
238 .utf8_text(source.as_bytes())
239 .wrap_err("type_name not utf-8")?;
240 type_text = Some(text.trim().to_string());
241 }
242 "identifier" => {
243 let text = child
244 .utf8_text(source.as_bytes())
245 .wrap_err("identifier not utf-8")?;
246 name = Some(text.to_string());
247 }
248 _ => {}
249 }
250 }
251 let Some(name) = name else {
252 return Ok(None);
253 };
254 let type_text = type_text.unwrap_or_default();
255
256 let is_constant = contains_keyword(raw, "constant");
257 let is_immutable = contains_keyword(raw, "immutable");
258
259 ensure!(
260 source.is_char_boundary(node.start_byte()) && source.is_char_boundary(node.end_byte()),
261 "tree-sitter produced non-UTF-8-boundary state variable range {}..{}",
262 node.start_byte(),
263 node.end_byte()
264 );
265 let chunk = crate::node::node_chunk_pub(node, source)?;
266
267 Ok(Some(ExtractedStateVariable {
268 id,
269 contract_id,
270 contract_name,
271 name,
272 type_text,
273 is_constant,
274 is_immutable,
275 relative_file_path,
276 chunk,
277 }))
278}
279
/// Reports whether `kw` occurs in `haystack` as a standalone word.
///
/// An occurrence counts only when the bytes immediately before and after it are not
/// identifier bytes (see [`is_id_byte`]); the start and end of `haystack` count as
/// boundaries. So `constant` is found in `uint constant X` but not in `constantValue`,
/// `my_constant` or `constant$`.
///
/// Occurrences are examined left to right without overlap, which is exact for keywords made
/// of identifier characters. An empty `kw` never matches.
fn contains_keyword(haystack: &str, kw: &str) -> bool {
    // An empty needle matches at every offset without advancing; treat it as "not present".
    if kw.is_empty() {
        return false;
    }
    let bytes = haystack.as_bytes();
    // True when `index` is inside `haystack` and holds an identifier byte.
    let touches_identifier = |index: Option<usize>| {
        index
            .and_then(|i| bytes.get(i))
            .is_some_and(|&b| is_id_byte(b))
    };
    haystack.match_indices(kw).any(|(start, matched)| {
        !touches_identifier(start.checked_sub(1))
            && !touches_identifier(Some(start + matched.len()))
    })
}

/// Returns `true` for bytes that can be part of a Solidity identifier: ASCII letters,
/// ASCII digits, `_` and `$`.
fn is_id_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$'
}