edb_engine/utils/
ast_prune.rs1use std::path::PathBuf;
18
19use eyre::{OptionExt, Result};
20use foundry_compilers::{
21 artifacts::{
22 output_selection::OutputSelection, Ast, Node, NodeType, Settings, Severity, Source,
23 SourceUnit, Sources,
24 },
25 solc::{SolcCompiler, SolcLanguage, SolcSettings, SolcVersionedInput},
26 CompilationError, Compiler, CompilerInput,
27};
28use semver::Version;
29
30pub fn compile_contract_source_to_source_unit(
32 solc_version: Version,
33 source: &str,
34 prune: bool,
35) -> Result<SourceUnit> {
36 let phantom_file_name = PathBuf::from("Contract.sol");
37 let sources = Sources::from_iter([(phantom_file_name.clone(), Source::new(source))]);
38 let settings = SolcSettings {
39 settings: Settings::new(OutputSelection::complete_output_selection()),
40 cli_settings: Default::default(),
41 };
42 let solc_input =
43 SolcVersionedInput::build(sources, settings, SolcLanguage::Solidity, solc_version);
44 let compiler = SolcCompiler::AutoDetect;
45 let output = compiler.compile(&solc_input)?;
46
47 let errors = output
49 .errors
50 .iter()
51 .filter(|e| e.severity() == Severity::Error)
52 .map(|e| format!("{e}"))
53 .collect::<Vec<_>>();
54 if !errors.is_empty() {
55 return Err(eyre::eyre!("Compiler error: {}", errors.join("\n")));
56 }
57
58 let mut ast = output
59 .sources
60 .get(&phantom_file_name)
61 .expect("No AST found")
62 .ast
63 .clone()
64 .expect("AST is not selected as output");
65
66 let source_unit = ASTPruner::convert(&mut ast, prune)?;
67 Ok(source_unit)
68}
69
70pub struct ASTPruner {}
88
89impl ASTPruner {
90 pub fn convert(ast: &mut Ast, prune: bool) -> Result<SourceUnit> {
92 if prune {
93 Self::prune(ast)?;
94 }
95 let serialized = serde_json::to_string(ast)?;
96
97 Ok(serde_json::from_str(&serialized)?)
98 }
99
100 fn prune(ast: &mut Ast) -> Result<()> {
101 for node in ast.nodes.iter_mut() {
102 Self::prune_node(node)?;
103 }
104
105 for (field, value) in ast.other.iter_mut() {
106 if field == "documentation" {
107 *value = serde_json::Value::Null;
109 } else {
110 Self::prune_value(value)?;
111 }
112 }
113
114 Ok(())
115 }
116
117 fn prune_node(node: &mut Node) -> Result<()> {
118 if matches!(node.node_type, NodeType::InlineAssembly) && !node.other.contains_key("AST") {
120 let ast = serde_json::json!({
124 "nodeType": "YulBlock",
125 "src": node.src,
126 "statements": [],
127 });
128 node.other.insert("AST".to_string(), ast);
129
130 node.other.insert("externalReferences".to_string(), serde_json::json!([]));
132
133 node.other.remove("operations");
135 }
136
137 if matches!(node.node_type, NodeType::ImportDirective) {
139 node.other.insert("symbolAliases".to_string(), serde_json::json!([]));
141 }
142
143 for (field, value) in node.other.iter_mut() {
145 if field == "documentation" {
146 *value = serde_json::Value::Null;
148 } else {
149 Self::prune_value(value)?;
150 }
151 }
152
153 if let Some(body) = &mut node.body {
154 Self::prune_node(body)?;
155 }
156
157 for node in node.nodes.iter_mut() {
158 Self::prune_node(node)?;
159 }
160
161 Ok(())
162 }
163
164 fn prune_value(value: &mut serde_json::Value) -> Result<()> {
165 match value {
166 serde_json::Value::Object(obj) => {
167 if let Some(node_type) = obj.get("nodeType") {
169 if node_type.as_str() == Some("InlineAssembly") {
170 if !obj.contains_key("AST") {
173 let ast = serde_json::json!({
174 "nodeType": "YulBlock",
175 "src": obj.get("src").ok_or_eyre("missing src")?.clone(),
176 "statements": [],
177 });
178 obj.insert("AST".to_string(), ast);
179 }
180
181 obj.insert("externalReferences".to_string(), serde_json::json!([]));
183
184 obj.remove("operations");
186 }
187 }
188
189 if let Some(node_type) = obj.get("nodeType") {
191 if node_type.as_str() == Some("ImportDirective") {
192 obj.insert("symbolAliases".to_string(), serde_json::json!([]));
194 }
195 }
196
197 for (field, value) in obj.iter_mut() {
199 if field == "documentation" {
200 *value = serde_json::Value::Null;
202 } else {
203 Self::prune_value(value)?;
204 }
205 }
206 }
207 serde_json::Value::Array(arr) => {
208 for value in arr.iter_mut() {
209 Self::prune_value(value)?;
210 }
211 }
212 _ => {}
213 }
214
215 Ok(())
216 }
217}
218
#[cfg(test)]
mod tests {
    use std::{path::PathBuf, str::FromStr, time::Duration};

    use alloy_chains::Chain;
    use alloy_primitives::Address;
    use eyre::Result;
    use foundry_block_explorers::Client;

    use crate::utils::OnchainCompiler;

    use super::*;

    /// Fetches and recompiles the verified source at `addr` on `chain`, then checks that
    /// every source's AST can be pruned and converted into a `SourceUnit`.
    ///
    /// Explorer responses and solc binaries are cached under `testdata/cache`.
    async fn download_and_compile(chain: Chain, addr: Address) -> Result<()> {
        let cache_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../testdata/cache");

        let cache_root = cache_dir.join("etherscan").join(chain.to_string());
        // Very long TTL so cached explorer responses are effectively reused forever.
        let cache_ttl = Duration::from_secs(u64::from(u32::MAX));
        let client =
            Client::builder().chain(chain)?.with_cache(Some(cache_root), cache_ttl).build()?;

        let compiler_cache_root = cache_dir.join("solc").join(chain.to_string());
        let compiler = OnchainCompiler::new(Some(compiler_cache_root))?;

        let mut artifact =
            compiler.compile(&client, addr).await?.ok_or_eyre("missing compiler output")?;
        for (_, contract) in artifact.output.sources.iter_mut() {
            ASTPruner::convert(contract.ast.as_mut().ok_or_eyre("AST does not exist")?, true)?;
        }

        Ok(())
    }

    /// Parses `addr` and runs [`download_and_compile`] against `Chain::default()`.
    async fn check_address(addr: &str) {
        let addr = Address::from_str(addr).expect("invalid test address");
        download_and_compile(Chain::default(), addr).await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_external_library() {
        check_address("0x0F6E8eF18FB5bb61D545fEe60f779D8aED60408F").await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_v0_8_18() {
        check_address("0xe45dfc26215312edc131e34ea9299fbca53275ca").await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_v0_8_17() {
        check_address("0x1111111254eeb25477b68fb85ed929f73a960582").await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_v0_7_6() {
        check_address("0x1f98431c8ad98523631ae4a59f267346ea31f984").await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_v0_6_12() {
        check_address("0x1eb4cf3a948e7d72a198fe073ccb8c7a948cd853").await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_v0_5_17() {
        check_address("0xee39E4A6820FFc4eDaA80fD3b5A59788D515832b").await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_solidity_v0_4_24() {
        check_address("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").await;
    }

    #[test]
    fn test_compile_contract_source() {
        let source_code = r#"
            // SPDX-License-Identifier: MIT
            pragma solidity ^0.8.0;

            contract SimpleStorage {
                uint256 private storedData;

                function set(uint256 x) public {
                    storedData = x;
                }

                function get() public view returns (uint256) {
                    return storedData;
                }
            }
        "#;

        let solc_version = Version::parse("0.8.0").expect("Invalid version");

        let source_unit = compile_contract_source_to_source_unit(solc_version, source_code, true)
            .unwrap_or_else(|err| panic!("Compilation failed: {err:?}"));

        assert!(!source_unit.nodes.is_empty(), "No AST nodes found in source unit");
    }
}