1use std::{
2 collections::{BTreeMap, HashMap, HashSet, VecDeque},
3 fmt::Write as _,
4 sync::Arc,
5};
6
7use color_eyre::eyre::{ContextCompat, WrapErr, ensure};
8use knowdit_repo_model::{
9 inheritance::{ContractInherit, InheritanceGraph},
10 storage::{ContractVariable, FunctionStateVariable, StateVariable, StorageGraph},
11};
12use knowdit_sol::{
13 cg::{
14 ExtractedContract, SolidityCallableKind, SolidityExtractionConfig,
15 extract_repo_contracts_functions,
16 },
17 storage::{
18 ExtractedInheritance, ExtractedStateVariable, extract_state_variables_and_inheritance,
19 },
20};
21use llmy::{
22 agent::{LLMYError, StepResult, tool::ToolBox},
23 client::{client::LLM, settings::LLMSettings},
24 harness::Agent,
25};
26use schemars::JsonSchema;
27use serde::Deserialize;
28use tokio::sync::Mutex;
29
30use crate::storage_prompt::{
31 StorageAgentPromptInput, storage_agent_system_prompt, storage_agent_user_prompt,
32};
33
34#[derive(Debug, Clone)]
35pub struct StorageAgentConfig {
36 pub extraction: SolidityExtractionConfig,
37 pub max_agent_steps: usize,
38 pub cache_key: String,
39 pub debug_prefix: Option<String>,
40 pub llm_settings: Option<LLMSettings>,
41}
42
43impl StorageAgentConfig {
44 pub fn new(
45 extraction: SolidityExtractionConfig,
46 max_agent_steps: usize,
47 cache_key: impl Into<String>,
48 debug_prefix: Option<String>,
49 llm_settings: Option<LLMSettings>,
50 ) -> Self {
51 Self {
52 extraction,
53 max_agent_steps,
54 cache_key: cache_key.into(),
55 debug_prefix,
56 llm_settings,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
62pub struct StorageAgentResult {
63 pub storage_graph: StorageGraph,
64 pub inheritance_graph: InheritanceGraph,
65 pub extracted_contracts: Vec<ExtractedContract>,
69 pub steps: usize,
70 pub final_response: String,
71}
72
73pub async fn analyze_repo_storage(
74 llm: &LLM,
75 config: StorageAgentConfig,
76) -> Result<StorageAgentResult, color_eyre::Report> {
77 let extraction = extract_repo_contracts_functions(&config.extraction).await?;
78 let storage_extraction =
79 extract_state_variables_and_inheritance(&extraction.repo_root, &extraction.contracts)
80 .await?;
81
82 let extracted_contracts = extraction.contracts.clone();
83 let runner = StorageRunner::new(extraction.contracts, storage_extraction);
84 let mut result = runner.run(llm, &config).await?;
85 result.extracted_contracts = extracted_contracts;
86 Ok(result)
87}
88
89struct StorageRunner {
90 contract_by_id: BTreeMap<i32, ExtractedContract>,
91 state_variables: Vec<ExtractedStateVariable>,
92 state_variable_by_id: BTreeMap<i32, ExtractedStateVariable>,
93 contract_inherits: Vec<ContractInherit>,
95 raw_inherits: Vec<ExtractedInheritance>,
97 visible_state_var_ids: BTreeMap<i32, Vec<i32>>,
99}
100
101impl StorageRunner {
102 fn new(
103 contracts: Vec<ExtractedContract>,
104 storage_extraction: knowdit_sol::storage::StorageExtractionResult,
105 ) -> Self {
106 let contract_by_id = contracts
107 .iter()
108 .cloned()
109 .map(|c| (c.id, c))
110 .collect::<BTreeMap<_, _>>();
111 let state_variable_by_id = storage_extraction
112 .state_variables
113 .iter()
114 .cloned()
115 .map(|v| (v.id, v))
116 .collect::<BTreeMap<_, _>>();
117
118 let mut contract_inherits: Vec<ContractInherit> = Vec::new();
119 for inh in &storage_extraction.inherits {
120 if let Some(parent_id) = inh.parent_contract_id {
121 contract_inherits.push(ContractInherit {
122 contract_id: inh.contract_id,
123 inherited_id: parent_id,
124 });
125 }
126 }
127 contract_inherits.sort_by_key(|i| (i.contract_id, i.inherited_id));
128 contract_inherits.dedup_by_key(|i| (i.contract_id, i.inherited_id));
129
130 let mut parents: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
132 for ci in &contract_inherits {
133 parents
134 .entry(ci.contract_id)
135 .or_default()
136 .push(ci.inherited_id);
137 }
138
139 let mut own_state_var_ids: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
141 for sv in &storage_extraction.state_variables {
142 own_state_var_ids
143 .entry(sv.contract_id)
144 .or_default()
145 .push(sv.id);
146 }
147
148 let mut visible_state_var_ids: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
150 for c in &contracts {
151 let mut visited: HashSet<i32> = HashSet::new();
152 let mut queue: VecDeque<i32> = VecDeque::new();
153 queue.push_back(c.id);
154 let mut visible: Vec<i32> = Vec::new();
155 while let Some(cid) = queue.pop_front() {
156 if !visited.insert(cid) {
157 continue;
158 }
159 if let Some(svs) = own_state_var_ids.get(&cid) {
160 visible.extend(svs.iter().copied());
161 }
162 if let Some(ps) = parents.get(&cid) {
163 for &p in ps {
164 if !visited.contains(&p) {
165 queue.push_back(p);
166 }
167 }
168 }
169 }
170 visible.sort();
171 visible.dedup();
172 visible_state_var_ids.insert(c.id, visible);
173 }
174
175 Self {
176 contract_by_id,
177 state_variables: storage_extraction.state_variables.clone(),
178 state_variable_by_id,
179 contract_inherits,
180 raw_inherits: storage_extraction.inherits,
181 visible_state_var_ids,
182 }
183 }
184
185 async fn run(
186 self,
187 llm: &LLM,
188 config: &StorageAgentConfig,
189 ) -> Result<StorageAgentResult, color_eyre::Report> {
190 let store = SharedStorageEdges::default();
191 let mut summaries: Vec<String> = Vec::new();
192 let mut total_steps = 0usize;
193
194 let mut contract_ids: Vec<i32> = self.contract_by_id.keys().copied().collect();
196 contract_ids.sort();
197
198 for cid in contract_ids {
199 let contract = self
200 .contract_by_id
201 .get(&cid)
202 .cloned()
203 .wrap_err("missing contract")?;
204 let has_functions = contract.functions.iter().any(|f| {
206 matches!(
207 f.kind,
208 SolidityCallableKind::Function
209 | SolidityCallableKind::Constructor
210 | SolidityCallableKind::Receive
211 | SolidityCallableKind::Fallback
212 | SolidityCallableKind::Modifier
213 )
214 });
215 let visible_svs = self
216 .visible_state_var_ids
217 .get(&cid)
218 .cloned()
219 .unwrap_or_default();
220 let visible_svs: Vec<i32> = visible_svs
222 .into_iter()
223 .filter(|sv_id| {
224 self.state_variable_by_id
225 .get(sv_id)
226 .map(|sv| !sv.is_constant)
227 .unwrap_or(false)
228 })
229 .collect();
230 if !has_functions || visible_svs.is_empty() {
231 continue;
232 }
233
234 let function_index_table = self.render_function_index(&contract);
235 let state_variable_index_table = self.render_state_var_index(&visible_svs);
236 let inheritance_chain = self.render_inheritance_chain(cid);
237 let contract_source =
238 render_with_line_numbers(&contract.chunk.content, contract.chunk.loc.start_line);
239
240 let prompt_input = StorageAgentPromptInput {
241 contract_label: format!(
242 "{} {} ({})",
243 contract.kind.as_str(),
244 contract.name,
245 contract.relative_file_path.display()
246 ),
247 contract_source,
248 function_index_table,
249 state_variable_index_table,
250 inheritance_chain,
251 };
252
253 let mut tools = ToolBox::new();
254 tools.add_tool(RecordStateVarAccessTool::new(
255 contract
256 .functions
257 .iter()
258 .map(|f| f.id)
259 .collect::<HashSet<_>>(),
260 visible_svs.iter().copied().collect::<HashSet<_>>(),
261 store.clone(),
262 ));
263
264 let cache_key = format!("{}-c{}", config.cache_key, contract.id);
265 let mut agent = Agent::new(storage_agent_system_prompt(), tools, cache_key);
266 let user = storage_agent_user_prompt(prompt_input);
267 let mut steps = 1;
268 let mut step_result = agent
269 .step_with_user(
270 user,
271 llm,
272 config.debug_prefix.as_deref(),
273 config.llm_settings.clone(),
274 )
275 .await
276 .wrap_err_with(|| {
277 format!(
278 "storage agent initial step failed for contract {}",
279 contract.name
280 )
281 })?;
282
283 loop {
284 if let StepResult::Stop(summary) = &step_result {
285 summaries.push(format!(
286 "{}::{}",
287 contract.relative_file_path.display(),
288 contract.name
289 ));
290 summaries.push(summary.trim().to_string());
291 break;
292 }
293 ensure!(
294 steps < config.max_agent_steps,
295 "storage agent exceeded max_agent_steps={} on contract {}",
296 config.max_agent_steps,
297 contract.name
298 );
299 steps += 1;
300 step_result = agent
301 .step(
302 llm,
303 config.debug_prefix.as_deref(),
304 config.llm_settings.clone(),
305 )
306 .await
307 .wrap_err_with(|| {
308 format!(
309 "storage agent step {} failed for contract {}",
310 steps, contract.name
311 )
312 })?;
313 }
314 total_steps += steps;
315 }
316
317 let mut state_variables_map: BTreeMap<i32, StateVariable> = BTreeMap::new();
319 for sv in &self.state_variables {
320 state_variables_map.insert(
321 sv.id,
322 StateVariable {
323 id: sv.id,
324 name: sv.name.clone(),
325 type_name: sv.type_text.clone(),
326 relative_file_path: sv.relative_file_path.clone(),
327 loc: sv.chunk.loc,
328 content: sv.chunk.content.clone(),
329 },
330 );
331 }
332 let mut contract_variables: Vec<ContractVariable> = self
333 .state_variables
334 .iter()
335 .map(|sv| ContractVariable {
336 contract_id: sv.contract_id,
337 state_variable_id: sv.id,
338 description: None,
339 })
340 .collect();
341 contract_variables.sort_by_key(|cv| (cv.contract_id, cv.state_variable_id));
342 let recorded = store.snapshot().await;
343 let storage_graph = StorageGraph {
344 state_variables: state_variables_map,
345 contract_variables,
346 function_state_variables: recorded,
347 };
348 let inheritance_graph = InheritanceGraph::new(self.contract_inherits.clone());
349
350 Ok(StorageAgentResult {
351 storage_graph,
352 inheritance_graph,
353 extracted_contracts: Vec::new(),
354 steps: total_steps,
355 final_response: summaries.join("\n"),
356 })
357 }
358
359 fn render_function_index(&self, contract: &ExtractedContract) -> String {
360 let mut s = String::new();
361 for f in &contract.functions {
362 let _ = writeln!(
363 s,
364 "- function_id={} kind={} name={} args=({}) line={}",
365 f.id,
366 f.kind.as_str(),
367 f.name,
368 f.args,
369 f.chunk.loc.start_line
370 );
371 }
372 if s.is_empty() {
373 "(no functions)".to_string()
374 } else {
375 s
376 }
377 }
378
379 fn render_state_var_index(&self, ids: &[i32]) -> String {
380 let mut s = String::new();
381 for sv_id in ids {
382 let Some(sv) = self.state_variable_by_id.get(sv_id) else {
383 continue;
384 };
385 let imm = if sv.is_immutable { " [immutable]" } else { "" };
386 let _ = writeln!(
387 s,
388 "- state_variable_id={} contract={} name={} type={}{}",
389 sv.id, sv.contract_name, sv.name, sv.type_text, imm
390 );
391 }
392 if s.is_empty() {
393 "(none — contract has no visible state variables)".to_string()
394 } else {
395 s
396 }
397 }
398
399 fn render_inheritance_chain(&self, contract_id: i32) -> String {
400 let mut chain: Vec<String> = Vec::new();
401 for inh in &self.raw_inherits {
402 if inh.contract_id != contract_id {
403 continue;
404 }
405 if let Some(pid) = inh.parent_contract_id {
406 let parent = self.contract_by_id.get(&pid);
407 let label = parent
408 .map(|p| format!("{} ({})", p.name, p.relative_file_path.display()))
409 .unwrap_or_else(|| inh.parent_name.clone());
410 chain.push(format!("- {}", label));
411 } else {
412 chain.push(format!(
413 "- {} (external / not in tree-sitter index)",
414 inh.parent_name
415 ));
416 }
417 }
418 if chain.is_empty() {
419 "(none)".to_string()
420 } else {
421 chain.join("\n")
422 }
423 }
424}
425
426#[derive(Debug, Clone, Default)]
427struct SharedStorageEdges {
428 inner: Arc<Mutex<HashMap<(i32, i32, bool), Option<String>>>>,
429}
430
431impl SharedStorageEdges {
432 async fn record(&self, fid: i32, sv: i32, is_write: bool, description: Option<String>) -> bool {
433 let mut g = self.inner.lock().await;
434 let key = (fid, sv, is_write);
435 if let std::collections::hash_map::Entry::Vacant(e) = g.entry(key) {
436 e.insert(description);
437 true
438 } else {
439 false
441 }
442 }
443
444 async fn snapshot(&self) -> Vec<FunctionStateVariable> {
445 let g = self.inner.lock().await;
446 let mut out: Vec<FunctionStateVariable> = g
447 .iter()
448 .map(|((fid, sv, w), desc)| FunctionStateVariable {
449 function_id: *fid,
450 state_variable_id: *sv,
451 is_write: *w,
452 description: desc.clone(),
453 })
454 .collect();
455 out.sort_by_key(|r| (r.function_id, r.state_variable_id, r.is_write));
456 out
457 }
458}
459
460#[derive(Debug, Clone, Deserialize, JsonSchema)]
461pub struct RecordStateVarAccessArgs {
462 pub function_id: i32,
464 pub state_variable_id: i32,
466 pub is_write: bool,
468 #[serde(default)]
470 pub description: Option<String>,
471}
472
473#[derive(Debug, Clone)]
474#[llmy::agent::tool(
475 arguments = RecordStateVarAccessArgs,
476 invoke = record_state_variable_access,
477 name = "record_state_variable_access",
478 description = "Record one (function -> state-variable) read or write edge for the contract under analysis. Pass the project-local function_id from the function index and the project-local state_variable_id from the visible state-var index. Use is_write=true for any write (assignment / delete / push / pop / compound op / ++ / --). Repeated calls with the same (function_id, state_variable_id, is_write) triple are idempotent. If the function_id or state_variable_id is not in the indexes for this contract, the tool refuses and returns the candidate ids; fix the selector and retry.",
479)]
480struct RecordStateVarAccessTool {
481 valid_function_ids: HashSet<i32>,
482 valid_state_var_ids: HashSet<i32>,
483 store: SharedStorageEdges,
484}
485
486impl RecordStateVarAccessTool {
487 fn new(
488 valid_function_ids: HashSet<i32>,
489 valid_state_var_ids: HashSet<i32>,
490 store: SharedStorageEdges,
491 ) -> Self {
492 Self {
493 valid_function_ids,
494 valid_state_var_ids,
495 store,
496 }
497 }
498
499 async fn record_state_variable_access(
500 &self,
501 args: RecordStateVarAccessArgs,
502 ) -> Result<String, LLMYError> {
503 if !self.valid_function_ids.contains(&args.function_id) {
504 let mut ids: Vec<i32> = self.valid_function_ids.iter().copied().collect();
505 ids.sort();
506 return Ok(format!(
507 "rejected: function_id={} is not in this contract's function index. Valid function_ids: {:?}",
508 args.function_id, ids
509 ));
510 }
511 if !self.valid_state_var_ids.contains(&args.state_variable_id) {
512 let mut ids: Vec<i32> = self.valid_state_var_ids.iter().copied().collect();
513 ids.sort();
514 return Ok(format!(
515 "rejected: state_variable_id={} is not visible to this contract. Visible state_variable_ids: {:?}",
516 args.state_variable_id, ids
517 ));
518 }
519 let inserted = self
520 .store
521 .record(
522 args.function_id,
523 args.state_variable_id,
524 args.is_write,
525 args.description.clone(),
526 )
527 .await;
528 Ok(format!(
529 "{} ({} -> sv{} {})",
530 if inserted {
531 "recorded"
532 } else {
533 "already recorded"
534 },
535 args.function_id,
536 args.state_variable_id,
537 if args.is_write { "WRITE" } else { "READ" }
538 ))
539 }
540}
541
/// Prefixes each line of `content` with its line number, counting up from `start_line`.
///
/// Numbers are right-aligned in a field of width 5 and followed by one space. Every emitted
/// line ends with `\n`; empty `content` produces an empty string.
fn render_with_line_numbers(content: &str, start_line: usize) -> String {
    content
        .lines()
        .enumerate()
        .fold(String::new(), |mut numbered, (offset, text)| {
            // Writing into a `String` does not fail, so the result is discarded.
            let _ = writeln!(numbered, "{:5} {}", start_line + offset, text);
            numbered
        })
}