Skip to main content

knowdit_agents/
storage.rs

1use std::{
2    collections::{BTreeMap, HashMap, HashSet, VecDeque},
3    fmt::Write as _,
4    sync::Arc,
5};
6
7use color_eyre::eyre::{ContextCompat, WrapErr, ensure};
8use knowdit_repo_model::{
9    inheritance::{ContractInherit, InheritanceGraph},
10    storage::{ContractVariable, FunctionStateVariable, StateVariable, StorageGraph},
11};
12use knowdit_sol::{
13    cg::{
14        ExtractedContract, SolidityCallableKind, SolidityExtractionConfig,
15        extract_repo_contracts_functions,
16    },
17    storage::{
18        ExtractedInheritance, ExtractedStateVariable, extract_state_variables_and_inheritance,
19    },
20};
21use llmy::{
22    agent::{LLMYError, StepResult, tool::ToolBox},
23    client::{client::LLM, settings::LLMSettings},
24    harness::Agent,
25};
26use schemars::JsonSchema;
27use serde::Deserialize;
28use tokio::sync::Mutex;
29
30use crate::storage_prompt::{
31    StorageAgentPromptInput, storage_agent_system_prompt, storage_agent_user_prompt,
32};
33
/// Configuration for [`analyze_repo_storage`].
#[derive(Debug, Clone)]
pub struct StorageAgentConfig {
    /// Tree-sitter extraction settings, passed to `extract_repo_contracts_functions`.
    pub extraction: SolidityExtractionConfig,
    /// Per-contract budget of agent steps. The `steps < max_agent_steps` check runs before
    /// every follow-up step (the initial step always runs); exceeding the budget fails the
    /// whole analysis with an error.
    pub max_agent_steps: usize,
    /// Prefix of the per-contract agent cache key, formatted as `"{cache_key}-c{contract_id}"`.
    pub cache_key: String,
    /// Optional debug prefix forwarded to every agent step.
    pub debug_prefix: Option<String>,
    /// Optional LLM settings, cloned and forwarded to every agent step.
    pub llm_settings: Option<LLMSettings>,
}
42
43impl StorageAgentConfig {
44    pub fn new(
45        extraction: SolidityExtractionConfig,
46        max_agent_steps: usize,
47        cache_key: impl Into<String>,
48        debug_prefix: Option<String>,
49        llm_settings: Option<LLMSettings>,
50    ) -> Self {
51        Self {
52            extraction,
53            max_agent_steps,
54            cache_key: cache_key.into(),
55            debug_prefix,
56            llm_settings,
57        }
58    }
59}
60
/// Output of [`analyze_repo_storage`].
#[derive(Debug, Clone)]
pub struct StorageAgentResult {
    /// Every extracted state variable, its declaring contract, and the
    /// function -> state-variable read/write edges recorded by the agents.
    pub storage_graph: StorageGraph,
    /// Project-local inheritance edges (only parents that resolved to an extracted contract).
    pub inheritance_graph: InheritanceGraph,
    /// Tree-sitter extracted contracts + functions, exposed so the CLI can persist a callgraph
    /// skeleton (contracts/functions/links, no call edges) using the same project-local ids
    /// as the storage analysis.
    pub extracted_contracts: Vec<ExtractedContract>,
    /// Total number of agent steps taken across all analysed contracts.
    pub steps: usize,
    /// For each analysed contract, a `path::Name` header line followed by the agent's
    /// trimmed final summary; all joined with `\n`.
    pub final_response: String,
}
72
73pub async fn analyze_repo_storage(
74    llm: &LLM,
75    config: StorageAgentConfig,
76) -> Result<StorageAgentResult, color_eyre::Report> {
77    let extraction = extract_repo_contracts_functions(&config.extraction).await?;
78    let storage_extraction =
79        extract_state_variables_and_inheritance(&extraction.repo_root, &extraction.contracts)
80            .await?;
81
82    let extracted_contracts = extraction.contracts.clone();
83    let runner = StorageRunner::new(extraction.contracts, storage_extraction);
84    let mut result = runner.run(llm, &config).await?;
85    result.extracted_contracts = extracted_contracts;
86    Ok(result)
87}
88
/// Internal driver of the storage analysis: owns the extracted contracts and state
/// variables, lookup tables over them, and the inheritance closure that scopes each
/// per-contract agent.
struct StorageRunner {
    /// Extracted contracts keyed by project-local contract id.
    contract_by_id: BTreeMap<i32, ExtractedContract>,
    /// Every extracted state variable (constants included), in extraction order.
    state_variables: Vec<ExtractedStateVariable>,
    /// The same state variables keyed by project-local id.
    state_variable_by_id: BTreeMap<i32, ExtractedStateVariable>,
    /// Project-local inheritance edges (parent_contract_id known).
    contract_inherits: Vec<ContractInherit>,
    /// All `is X` rows including unresolved external parents (for prompt context).
    raw_inherits: Vec<ExtractedInheritance>,
    /// State vars visible to each contract = own + transitive parents.
    visible_state_var_ids: BTreeMap<i32, Vec<i32>>,
}
100
101impl StorageRunner {
102    fn new(
103        contracts: Vec<ExtractedContract>,
104        storage_extraction: knowdit_sol::storage::StorageExtractionResult,
105    ) -> Self {
106        let contract_by_id = contracts
107            .iter()
108            .cloned()
109            .map(|c| (c.id, c))
110            .collect::<BTreeMap<_, _>>();
111        let state_variable_by_id = storage_extraction
112            .state_variables
113            .iter()
114            .cloned()
115            .map(|v| (v.id, v))
116            .collect::<BTreeMap<_, _>>();
117
118        let mut contract_inherits: Vec<ContractInherit> = Vec::new();
119        for inh in &storage_extraction.inherits {
120            if let Some(parent_id) = inh.parent_contract_id {
121                contract_inherits.push(ContractInherit {
122                    contract_id: inh.contract_id,
123                    inherited_id: parent_id,
124                });
125            }
126        }
127        contract_inherits.sort_by_key(|i| (i.contract_id, i.inherited_id));
128        contract_inherits.dedup_by_key(|i| (i.contract_id, i.inherited_id));
129
130        // Build adjacency: child -> parents
131        let mut parents: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
132        for ci in &contract_inherits {
133            parents
134                .entry(ci.contract_id)
135                .or_default()
136                .push(ci.inherited_id);
137        }
138
139        // own_state_var_ids
140        let mut own_state_var_ids: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
141        for sv in &storage_extraction.state_variables {
142            own_state_var_ids
143                .entry(sv.contract_id)
144                .or_default()
145                .push(sv.id);
146        }
147
148        // visible_state_var_ids: BFS over inheritance closure
149        let mut visible_state_var_ids: BTreeMap<i32, Vec<i32>> = BTreeMap::new();
150        for c in &contracts {
151            let mut visited: HashSet<i32> = HashSet::new();
152            let mut queue: VecDeque<i32> = VecDeque::new();
153            queue.push_back(c.id);
154            let mut visible: Vec<i32> = Vec::new();
155            while let Some(cid) = queue.pop_front() {
156                if !visited.insert(cid) {
157                    continue;
158                }
159                if let Some(svs) = own_state_var_ids.get(&cid) {
160                    visible.extend(svs.iter().copied());
161                }
162                if let Some(ps) = parents.get(&cid) {
163                    for &p in ps {
164                        if !visited.contains(&p) {
165                            queue.push_back(p);
166                        }
167                    }
168                }
169            }
170            visible.sort();
171            visible.dedup();
172            visible_state_var_ids.insert(c.id, visible);
173        }
174
175        Self {
176            contract_by_id,
177            state_variables: storage_extraction.state_variables.clone(),
178            state_variable_by_id,
179            contract_inherits,
180            raw_inherits: storage_extraction.inherits,
181            visible_state_var_ids,
182        }
183    }
184
185    async fn run(
186        self,
187        llm: &LLM,
188        config: &StorageAgentConfig,
189    ) -> Result<StorageAgentResult, color_eyre::Report> {
190        let store = SharedStorageEdges::default();
191        let mut summaries: Vec<String> = Vec::new();
192        let mut total_steps = 0usize;
193
194        // Iterate contracts in deterministic order. Skip Interface kind (no bodies, no R/W).
195        let mut contract_ids: Vec<i32> = self.contract_by_id.keys().copied().collect();
196        contract_ids.sort();
197
198        for cid in contract_ids {
199            let contract = self
200                .contract_by_id
201                .get(&cid)
202                .cloned()
203                .wrap_err("missing contract")?;
204            // Skip interfaces and library/contract entries that have no functions and no state vars.
205            let has_functions = contract.functions.iter().any(|f| {
206                matches!(
207                    f.kind,
208                    SolidityCallableKind::Function
209                        | SolidityCallableKind::Constructor
210                        | SolidityCallableKind::Receive
211                        | SolidityCallableKind::Fallback
212                        | SolidityCallableKind::Modifier
213                )
214            });
215            let visible_svs = self
216                .visible_state_var_ids
217                .get(&cid)
218                .cloned()
219                .unwrap_or_default();
220            // Drop constants — agent doesn't need to see them.
221            let visible_svs: Vec<i32> = visible_svs
222                .into_iter()
223                .filter(|sv_id| {
224                    self.state_variable_by_id
225                        .get(sv_id)
226                        .map(|sv| !sv.is_constant)
227                        .unwrap_or(false)
228                })
229                .collect();
230            if !has_functions || visible_svs.is_empty() {
231                continue;
232            }
233
234            let function_index_table = self.render_function_index(&contract);
235            let state_variable_index_table = self.render_state_var_index(&visible_svs);
236            let inheritance_chain = self.render_inheritance_chain(cid);
237            let contract_source =
238                render_with_line_numbers(&contract.chunk.content, contract.chunk.loc.start_line);
239
240            let prompt_input = StorageAgentPromptInput {
241                contract_label: format!(
242                    "{} {} ({})",
243                    contract.kind.as_str(),
244                    contract.name,
245                    contract.relative_file_path.display()
246                ),
247                contract_source,
248                function_index_table,
249                state_variable_index_table,
250                inheritance_chain,
251            };
252
253            let mut tools = ToolBox::new();
254            tools.add_tool(RecordStateVarAccessTool::new(
255                contract
256                    .functions
257                    .iter()
258                    .map(|f| f.id)
259                    .collect::<HashSet<_>>(),
260                visible_svs.iter().copied().collect::<HashSet<_>>(),
261                store.clone(),
262            ));
263
264            let cache_key = format!("{}-c{}", config.cache_key, contract.id);
265            let mut agent = Agent::new(storage_agent_system_prompt(), tools, cache_key);
266            let user = storage_agent_user_prompt(prompt_input);
267            let mut steps = 1;
268            let mut step_result = agent
269                .step_with_user(
270                    user,
271                    llm,
272                    config.debug_prefix.as_deref(),
273                    config.llm_settings.clone(),
274                )
275                .await
276                .wrap_err_with(|| {
277                    format!(
278                        "storage agent initial step failed for contract {}",
279                        contract.name
280                    )
281                })?;
282
283            loop {
284                if let StepResult::Stop(summary) = &step_result {
285                    summaries.push(format!(
286                        "{}::{}",
287                        contract.relative_file_path.display(),
288                        contract.name
289                    ));
290                    summaries.push(summary.trim().to_string());
291                    break;
292                }
293                ensure!(
294                    steps < config.max_agent_steps,
295                    "storage agent exceeded max_agent_steps={} on contract {}",
296                    config.max_agent_steps,
297                    contract.name
298                );
299                steps += 1;
300                step_result = agent
301                    .step(
302                        llm,
303                        config.debug_prefix.as_deref(),
304                        config.llm_settings.clone(),
305                    )
306                    .await
307                    .wrap_err_with(|| {
308                        format!(
309                            "storage agent step {} failed for contract {}",
310                            steps, contract.name
311                        )
312                    })?;
313            }
314            total_steps += steps;
315        }
316
317        // Build StorageGraph
318        let mut state_variables_map: BTreeMap<i32, StateVariable> = BTreeMap::new();
319        for sv in &self.state_variables {
320            state_variables_map.insert(
321                sv.id,
322                StateVariable {
323                    id: sv.id,
324                    name: sv.name.clone(),
325                    type_name: sv.type_text.clone(),
326                    relative_file_path: sv.relative_file_path.clone(),
327                    loc: sv.chunk.loc,
328                    content: sv.chunk.content.clone(),
329                },
330            );
331        }
332        let mut contract_variables: Vec<ContractVariable> = self
333            .state_variables
334            .iter()
335            .map(|sv| ContractVariable {
336                contract_id: sv.contract_id,
337                state_variable_id: sv.id,
338                description: None,
339            })
340            .collect();
341        contract_variables.sort_by_key(|cv| (cv.contract_id, cv.state_variable_id));
342        let recorded = store.snapshot().await;
343        let storage_graph = StorageGraph {
344            state_variables: state_variables_map,
345            contract_variables,
346            function_state_variables: recorded,
347        };
348        let inheritance_graph = InheritanceGraph::new(self.contract_inherits.clone());
349
350        Ok(StorageAgentResult {
351            storage_graph,
352            inheritance_graph,
353            extracted_contracts: Vec::new(),
354            steps: total_steps,
355            final_response: summaries.join("\n"),
356        })
357    }
358
359    fn render_function_index(&self, contract: &ExtractedContract) -> String {
360        let mut s = String::new();
361        for f in &contract.functions {
362            let _ = writeln!(
363                s,
364                "- function_id={} kind={} name={} args=({}) line={}",
365                f.id,
366                f.kind.as_str(),
367                f.name,
368                f.args,
369                f.chunk.loc.start_line
370            );
371        }
372        if s.is_empty() {
373            "(no functions)".to_string()
374        } else {
375            s
376        }
377    }
378
379    fn render_state_var_index(&self, ids: &[i32]) -> String {
380        let mut s = String::new();
381        for sv_id in ids {
382            let Some(sv) = self.state_variable_by_id.get(sv_id) else {
383                continue;
384            };
385            let imm = if sv.is_immutable { " [immutable]" } else { "" };
386            let _ = writeln!(
387                s,
388                "- state_variable_id={} contract={} name={} type={}{}",
389                sv.id, sv.contract_name, sv.name, sv.type_text, imm
390            );
391        }
392        if s.is_empty() {
393            "(none — contract has no visible state variables)".to_string()
394        } else {
395            s
396        }
397    }
398
399    fn render_inheritance_chain(&self, contract_id: i32) -> String {
400        let mut chain: Vec<String> = Vec::new();
401        for inh in &self.raw_inherits {
402            if inh.contract_id != contract_id {
403                continue;
404            }
405            if let Some(pid) = inh.parent_contract_id {
406                let parent = self.contract_by_id.get(&pid);
407                let label = parent
408                    .map(|p| format!("{} ({})", p.name, p.relative_file_path.display()))
409                    .unwrap_or_else(|| inh.parent_name.clone());
410                chain.push(format!("- {}", label));
411            } else {
412                chain.push(format!(
413                    "- {} (external / not in tree-sitter index)",
414                    inh.parent_name
415                ));
416            }
417        }
418        if chain.is_empty() {
419            "(none)".to_string()
420        } else {
421            chain.join("\n")
422        }
423    }
424}
425
/// Shared set of recorded (function -> state-variable) access edges, guarded by an async
/// (`tokio`) mutex.
///
/// Cloning only clones the `Arc`, so every per-contract tool instance handed a clone writes
/// into the same map.
#[derive(Debug, Clone, Default)]
struct SharedStorageEdges {
    /// Keyed by `(function_id, state_variable_id, is_write)`; the value is the optional
    /// description attached to that edge.
    inner: Arc<Mutex<HashMap<(i32, i32, bool), Option<String>>>>,
}
430
431impl SharedStorageEdges {
432    async fn record(&self, fid: i32, sv: i32, is_write: bool, description: Option<String>) -> bool {
433        let mut g = self.inner.lock().await;
434        let key = (fid, sv, is_write);
435        if let std::collections::hash_map::Entry::Vacant(e) = g.entry(key) {
436            e.insert(description);
437            true
438        } else {
439            // Keep first description if present; ignore further.
440            false
441        }
442    }
443
444    async fn snapshot(&self) -> Vec<FunctionStateVariable> {
445        let g = self.inner.lock().await;
446        let mut out: Vec<FunctionStateVariable> = g
447            .iter()
448            .map(|((fid, sv, w), desc)| FunctionStateVariable {
449                function_id: *fid,
450                state_variable_id: *sv,
451                is_write: *w,
452                description: desc.clone(),
453            })
454            .collect();
455        out.sort_by_key(|r| (r.function_id, r.state_variable_id, r.is_write));
456        out
457    }
458}
459
// Arguments of the `record_state_variable_access` tool.
// NOTE: schemars' `JsonSchema` derive turns `///` doc comments into schema descriptions,
// which presumably reach the model through the tool definition. Those field docs are
// therefore runtime text; annotate this item with plain `//` comments only.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct RecordStateVarAccessArgs {
    /// Project-local id of the function/modifier (must be one listed in the function index).
    pub function_id: i32,
    /// Project-local id of the state variable (must be one listed in the visible state var index).
    pub state_variable_id: i32,
    /// `true` for stores/deletes/array.push/array.pop/compound assigns/++/--; `false` for reads.
    pub is_write: bool,
    /// Optional 1-line description of the access (e.g. "post-increment", "delete on map slot").
    // `serde(default)`: an omitted field deserializes as `None`.
    #[serde(default)]
    pub description: Option<String>,
}
472
// Per-contract instance of the `record_state_variable_access` tool. It validates ids
// against the sets supplied by its creator before writing into the shared edge store.
// NOTE: plain `//` comments are used because this item is processed by the
// `llmy::agent::tool` attribute macro.
#[derive(Debug, Clone)]
#[llmy::agent::tool(
    arguments = RecordStateVarAccessArgs,
    invoke = record_state_variable_access,
    name = "record_state_variable_access",
    description = "Record one (function -> state-variable) read or write edge for the contract under analysis. Pass the project-local function_id from the function index and the project-local state_variable_id from the visible state-var index. Use is_write=true for any write (assignment / delete / push / pop / compound op / ++ / --). Repeated calls with the same (function_id, state_variable_id, is_write) triple are idempotent. If the function_id or state_variable_id is not in the indexes for this contract, the tool refuses and returns the candidate ids; fix the selector and retry.",
)]
struct RecordStateVarAccessTool {
    // Function ids this tool accepts (the analysed contract's callables).
    valid_function_ids: HashSet<i32>,
    // State-variable ids this tool accepts (the ids listed in the prompt's index).
    valid_state_var_ids: HashSet<i32>,
    // Shared sink for accepted edges.
    store: SharedStorageEdges,
}
485
486impl RecordStateVarAccessTool {
487    fn new(
488        valid_function_ids: HashSet<i32>,
489        valid_state_var_ids: HashSet<i32>,
490        store: SharedStorageEdges,
491    ) -> Self {
492        Self {
493            valid_function_ids,
494            valid_state_var_ids,
495            store,
496        }
497    }
498
499    async fn record_state_variable_access(
500        &self,
501        args: RecordStateVarAccessArgs,
502    ) -> Result<String, LLMYError> {
503        if !self.valid_function_ids.contains(&args.function_id) {
504            let mut ids: Vec<i32> = self.valid_function_ids.iter().copied().collect();
505            ids.sort();
506            return Ok(format!(
507                "rejected: function_id={} is not in this contract's function index. Valid function_ids: {:?}",
508                args.function_id, ids
509            ));
510        }
511        if !self.valid_state_var_ids.contains(&args.state_variable_id) {
512            let mut ids: Vec<i32> = self.valid_state_var_ids.iter().copied().collect();
513            ids.sort();
514            return Ok(format!(
515                "rejected: state_variable_id={} is not visible to this contract. Visible state_variable_ids: {:?}",
516                args.state_variable_id, ids
517            ));
518        }
519        let inserted = self
520            .store
521            .record(
522                args.function_id,
523                args.state_variable_id,
524                args.is_write,
525                args.description.clone(),
526            )
527            .await;
528        Ok(format!(
529            "{} ({} -> sv{} {})",
530            if inserted {
531                "recorded"
532            } else {
533                "already recorded"
534            },
535            args.function_id,
536            args.state_variable_id,
537            if args.is_write { "WRITE" } else { "READ" }
538        ))
539    }
540}
541
/// Prefixes each line of `content` with a line number, right-aligned to a minimum width of
/// five and followed by a space. Numbering starts at `start_line` and increases by one per line.
///
/// Lines are split with `str::lines` (so `\r\n` endings are stripped), and every output line
/// is terminated with `\n`. Empty input yields an empty string.
fn render_with_line_numbers(content: &str, start_line: usize) -> String {
    content
        .lines()
        .enumerate()
        .fold(String::new(), |mut rendered, (offset, text)| {
            let _ = writeln!(rendered, "{:5} {}", start_line + offset, text);
            rendered
        })
}