Skip to main content

knowdit_sol/
cg.rs

1use std::{fmt, path::PathBuf};
2
3use color_eyre::eyre::{ContextCompat, WrapErr, ensure};
4use knowdit_repo_model::cg::FileChunk;
5use serde::Serialize;
6use tree_sitter::Parser;
7
8pub use crate::filter::{
9    SolidityExtractionConfig, filter_analysis_source_files, normalize_relative_source_files,
10};
11use crate::node::{
12    callable_args, callable_kind_from_node, callable_name, collect_contract_nodes, node_chunk,
13    node_field_text,
14};
15
/// One Solidity source file plus its already-loaded content.
///
/// The tree-sitter pass needs the source bytes. This struct lets callers
/// that have already read the file (e.g. a [`crate::filter::SolidityExtractionConfig`]-
/// driven CLI flow or a `knowdit_project::ProjectScopeContent` snapshot)
/// pass the content straight through instead of forcing a second
/// `tokio::fs::read_to_string` per file.
#[derive(Debug, Clone)]
pub struct SoliditySourceInput {
    /// Path of the source file relative to the project root. It is copied
    /// verbatim onto each extracted contract / callable so downstream
    /// consumers can hash-locate them. It is also used in warnings and error
    /// messages that refer to this file.
    pub relative_path: PathBuf,
    /// Full source text of the file. tree-sitter parses its UTF-8 bytes
    /// directly. Byte offsets reported in extraction errors index into this text.
    pub content: String,
}
31
/// Contracts/functions extracted from a Solidity repository before any LLM callgraph analysis.
///
/// Produced by [`extract_repo_contracts_functions`].
#[derive(Debug, Clone, Serialize)]
pub struct SolidityExtractionResult {
    /// Canonicalized (absolute, symlink-resolved) repository root.
    pub repo_root: PathBuf,
    /// Every Solidity source file that was read and parsed, as returned by
    /// `normalize_relative_source_files`.
    /// NOTE(review): presumably relative to `repo_root`; confirm in `crate::filter`.
    pub source_files: Vec<PathBuf>,
    /// Files selected for analysis, as returned by `filter_analysis_source_files`.
    /// NOTE(review): presumably a subset of `source_files`; confirm in `crate::filter`.
    pub analysis_source_files: Vec<PathBuf>,
    /// Every contract, interface, and library extracted from `source_files`,
    /// grouped in `source_files` order. Never empty when produced by
    /// [`extract_repo_contracts_functions`].
    pub contracts: Vec<ExtractedContract>,
}
40
/// Solidity declaration kind for a contract-like source item.
///
/// The declaration order of the variants determines the derived
/// `PartialOrd`/`Ord` ordering. Reordering them changes sort results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
pub enum SolidityContractKind {
    /// Declared with the `contract` keyword.
    Contract,
    /// Declared with the `interface` keyword.
    Interface,
    /// Declared with the `library` keyword.
    Library,
}
48
49impl SolidityContractKind {
50    pub(crate) fn from_node_kind(kind: &str) -> Option<Self> {
51        match kind {
52            "contract_declaration" => Some(Self::Contract),
53            "interface_declaration" => Some(Self::Interface),
54            "library_declaration" => Some(Self::Library),
55            _ => None,
56        }
57    }
58
59    pub fn as_str(self) -> &'static str {
60        match self {
61            Self::Contract => "contract",
62            Self::Interface => "interface",
63            Self::Library => "library",
64        }
65    }
66}
67
/// Solidity callable kind represented as a callgraph function node.
///
/// Which variant a tree-sitter node maps to is decided by
/// `callable_kind_from_node` in `crate::node`. The declaration order of the
/// variants determines the derived `PartialOrd`/`Ord` ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
pub enum SolidityCallableKind {
    /// Rendered as `"function"`.
    Function,
    /// Rendered as `"constructor"`.
    Constructor,
    /// Rendered as `"receive"`.
    Receive,
    /// Rendered as `"fallback"`.
    Fallback,
    /// Rendered as `"modifier"`.
    Modifier,
}
77
78impl SolidityCallableKind {
79    pub fn as_str(self) -> &'static str {
80        match self {
81            Self::Function => "function",
82            Self::Constructor => "constructor",
83            Self::Receive => "receive",
84            Self::Fallback => "fallback",
85            Self::Modifier => "modifier",
86        }
87    }
88}
89
/// Repository callgraph function node kind.
///
/// Derived from the kind of the enclosing contract-like declaration, not
/// from the callable itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
pub enum SolidityFunctionNodeKind {
    /// A callable declared inside a `contract` or `library`.
    ContractFunctionDefinition,
    /// A callable declared inside an `interface`.
    InterfaceFunctionDeclaration,
}
96
97impl SolidityFunctionNodeKind {
98    fn from_container_kind(kind: SolidityContractKind) -> Self {
99        match kind {
100            SolidityContractKind::Interface => Self::InterfaceFunctionDeclaration,
101            SolidityContractKind::Contract | SolidityContractKind::Library => {
102                Self::ContractFunctionDefinition
103            }
104        }
105    }
106
107    pub fn as_str(self) -> &'static str {
108        match self {
109            Self::ContractFunctionDefinition => "contract_function_definition",
110            Self::InterfaceFunctionDeclaration => "interface_function_declaration",
111        }
112    }
113
114    pub fn is_definition(self) -> bool {
115        matches!(self, Self::ContractFunctionDefinition)
116    }
117}
118
/// A contract/interface/library extracted from Solidity source.
#[derive(Debug, Clone, Serialize)]
pub struct ExtractedContract {
    /// 1-based id assigned in extraction order. It is unique among the
    /// contracts returned by a single `extract_contracts_functions` call.
    pub id: i32,
    /// Text of the declaration node's `name` field.
    pub name: String,
    /// Which contract-like declaration this is.
    pub kind: SolidityContractKind,
    /// Source file path, copied from the input's `relative_path`.
    pub relative_file_path: PathBuf,
    /// Chunk produced by `node_chunk` for the whole declaration node.
    pub chunk: FileChunk,
    /// Callables that are direct named children of the declaration body,
    /// in source order.
    pub functions: Vec<ExtractedCallable>,
}
129
130impl fmt::Display for ExtractedContract {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(
133            f,
134            "Contract({} {} @ {} (id={}, functions={}))",
135            self.kind.as_str(),
136            self.name,
137            self.relative_file_path.display(),
138            self.id,
139            self.functions.len()
140        )
141    }
142}
143
/// A function-like Solidity item extracted by tree-sitter.
#[derive(Debug, Clone, Serialize)]
pub struct ExtractedCallable {
    /// 1-based id from a counter separate from the contract ids. It is unique
    /// among callables returned by a single `extract_contracts_functions` call.
    pub id: i32,
    /// `id` of the enclosing [`ExtractedContract`].
    pub contract_id: i32,
    /// `name` of the enclosing [`ExtractedContract`].
    pub contract_name: String,
    /// Callable kind, as classified by `callable_kind_from_node`.
    pub kind: SolidityCallableKind,
    /// Node kind derived from the enclosing container's kind.
    pub node_kind: SolidityFunctionNodeKind,
    /// Name produced by `callable_name`.
    pub name: String,
    /// Text produced by `callable_args`.
    /// NOTE(review): presumably the parameter-list source text; confirm in `crate::node`.
    pub args: String,
    /// Source file path, copied from the input's `relative_path`.
    pub relative_file_path: PathBuf,
    /// Chunk produced by `node_chunk` for the callable's node.
    pub chunk: FileChunk,
}
157
158/// Extract Solidity contracts/functions without invoking the LLM.
159pub async fn extract_repo_contracts_functions(
160    config: &SolidityExtractionConfig,
161) -> Result<SolidityExtractionResult, color_eyre::Report> {
162    let repo_root = config.repo_root.canonicalize().wrap_err_with(|| {
163        format!(
164            "failed to canonicalize repo root {}",
165            config.repo_root.display()
166        )
167    })?;
168    ensure!(
169        repo_root.is_dir(),
170        "repo root {} is not a directory",
171        repo_root.display()
172    );
173
174    let source_files = normalize_relative_source_files(&repo_root, config.source_files.clone())?;
175    ensure!(
176        !source_files.is_empty(),
177        "no Solidity source files provided under {}",
178        repo_root.display()
179    );
180    let analysis_source_files = filter_analysis_source_files(
181        source_files.clone(),
182        normalize_relative_source_files(&repo_root, config.analysis_source_files.clone())?,
183    )?;
184
185    let mut inputs: Vec<SoliditySourceInput> = Vec::with_capacity(source_files.len());
186    for relative in &source_files {
187        let absolute = repo_root.join(relative);
188        let content = tokio::fs::read_to_string(&absolute)
189            .await
190            .wrap_err_with(|| format!("failed to read Solidity file {}", absolute.display()))?;
191        inputs.push(SoliditySourceInput {
192            relative_path: relative.clone(),
193            content,
194        });
195    }
196    let contracts = extract_contracts_functions(&inputs)?;
197    ensure!(
198        !contracts.is_empty(),
199        "no Solidity contracts/interfaces/libraries were extracted under {}",
200        repo_root.display()
201    );
202
203    Ok(SolidityExtractionResult {
204        repo_root,
205        source_files,
206        analysis_source_files,
207        contracts,
208    })
209}
210
211/// Extract Solidity contracts and callables with tree-sitter from
212/// pre-loaded source content. Sync: tree-sitter parsing is CPU-bound
213/// and the caller has already taken care of disk I/O. The previous
214/// `(repo_root, &[PathBuf])` signature is preserved at the outer
215/// [`extract_repo_contracts_functions`] entry point, which handles
216/// reading + then defers here.
217pub fn extract_contracts_functions(
218    inputs: &[SoliditySourceInput],
219) -> Result<Vec<ExtractedContract>, color_eyre::Report> {
220    let mut parser = Parser::new();
221    let language = tree_sitter_solidity::LANGUAGE.into();
222    parser
223        .set_language(&language)
224        .wrap_err("failed to load tree-sitter Solidity grammar")?;
225
226    let mut contracts = Vec::new();
227    let mut next_contract_id = 1;
228    let mut next_function_id = 1;
229
230    for input in inputs {
231        let relative_file_path = &input.relative_path;
232        let source = &input.content;
233        let tree = parser.parse(source.as_bytes(), None).wrap_err_with(|| {
234            format!(
235                "failed to parse Solidity file {}",
236                relative_file_path.display()
237            )
238        })?;
239        let root = tree.root_node();
240        if root.has_error() {
241            tracing::warn!(
242                path = %relative_file_path.display(),
243                "tree-sitter reported syntax errors; attempting best-effort extraction"
244            );
245        }
246
247        let mut contract_nodes = Vec::new();
248        collect_contract_nodes(root, &mut contract_nodes);
249
250        for contract_node in contract_nodes {
251            let Some(contract_kind) = SolidityContractKind::from_node_kind(contract_node.kind())
252            else {
253                continue;
254            };
255            let name = node_field_text(contract_node, "name", source).wrap_err_with(|| {
256                format!(
257                    "failed to read contract name in {} at byte {}",
258                    relative_file_path.display(),
259                    contract_node.start_byte()
260                )
261            })?;
262            let contract_id = next_contract_id;
263            next_contract_id += 1;
264
265            let mut functions = Vec::new();
266            if let Some(body) = contract_node.child_by_field_name("body") {
267                let mut cursor = body.walk();
268                let body_children = body.named_children(&mut cursor).collect::<Vec<_>>();
269                for child in body_children {
270                    let Some(callable_kind) = callable_kind_from_node(child, source) else {
271                        continue;
272                    };
273                    let callable_name = callable_name(child, callable_kind, source)?;
274                    let args = callable_args(child, source)?;
275                    let chunk = node_chunk(child, source)?;
276                    functions.push(ExtractedCallable {
277                        id: next_function_id,
278                        contract_id,
279                        contract_name: name.clone(),
280                        kind: callable_kind,
281                        node_kind: SolidityFunctionNodeKind::from_container_kind(contract_kind),
282                        name: callable_name,
283                        args,
284                        relative_file_path: relative_file_path.clone(),
285                        chunk,
286                    });
287                    next_function_id += 1;
288                }
289            }
290
291            contracts.push(ExtractedContract {
292                id: contract_id,
293                name,
294                kind: contract_kind,
295                relative_file_path: relative_file_path.clone(),
296                chunk: node_chunk(contract_node, source)?,
297                functions,
298            });
299        }
300    }
301
302    Ok(contracts)
303}