1use std::{
18 sync::Arc,
20};
21
22use delegate::delegate;
24use foundry_compilers::artifacts::{
25 ast::SourceLocation,
26 BlockOrStatement,
27 DoWhileStatement,
28 Expression,
29 ForStatement,
30 FunctionCall,
31 FunctionDefinition,
32 IfStatement,
33 ModifierDefinition,
34 Statement,
35 TryStatement,
36 WhileStatement, };
38use once_cell::sync::OnceCell;
39use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
40use serde::{Deserialize, Serialize};
41
42use crate::analysis::{macros::universal_id, VariableRef, VariableScopeRef, UFID};
43
// Declares `USID`, the identifier type stored in `Step::usid`, via the crate-local
// `universal_id!` macro. `Step::new` allocates fresh values with `USID::next()`.
// NOTE(review): `=> 0` is presumably the counter's starting value — confirm against
// `crate::analysis::macros::universal_id`.
universal_id! {
    USID => 0
}
48
/// Shared handle to a [`Step`] stored behind `Arc<RwLock<_>>`.
///
/// Cloning a `StepRef` clones the `Arc`, so all clones observe the same underlying
/// `Step`. Frequently-read properties are memoized in per-handle `OnceCell`s on first
/// access; the derived `Clone` copies whatever those caches already hold.
///
/// NOTE(review): the caches are never invalidated, so changes made through
/// `StepRef::write` after a property has been cached are not reflected by that
/// property's getter.
#[derive(Debug, Clone)]
pub struct StepRef {
    /// The shared, lock-protected step.
    inner: Arc<RwLock<Step>>,
    /// Cached copy of `Step::usid`.
    usid: OnceCell<USID>,
    /// Cached copy of `Step::ufid`.
    ufid: OnceCell<UFID>,
    /// Cached clone of `Step::variant`.
    variant: OnceCell<StepVariant>,
    /// Cached result of `StepRef::function_calls`.
    function_calls: OnceCell<usize>,
}
63
64impl From<Step> for StepRef {
65 fn from(step: Step) -> Self {
66 Self::new(step)
67 }
68}
69
70impl Serialize for StepRef {
71 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
72 where
73 S: serde::Serializer,
74 {
75 self.inner.read().serialize(serializer)
77 }
78}
79
80impl<'de> Deserialize<'de> for StepRef {
81 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82 where
83 D: serde::Deserializer<'de>,
84 {
85 let step = Step::deserialize(deserializer)?;
87 Ok(Self::new(step))
88 }
89}
90
91impl StepRef {
92 pub fn new(inner: Step) -> Self {
94 Self {
95 inner: Arc::new(RwLock::new(inner)),
96 usid: OnceCell::new(),
97 ufid: OnceCell::new(),
98 variant: OnceCell::new(),
99 function_calls: OnceCell::new(),
100 }
101 }
102
103 pub(crate) fn read(&self) -> RwLockReadGuard<'_, Step> {
104 self.inner.read()
105 }
106
107 pub(crate) fn write(&self) -> RwLockWriteGuard<'_, Step> {
108 self.inner.write()
109 }
110
111 pub fn usid(&self) -> USID {
113 *self.usid.get_or_init(|| self.inner.read().usid)
114 }
115
116 pub fn ufid(&self) -> UFID {
118 *self.ufid.get_or_init(|| self.inner.read().ufid)
119 }
120
121 pub fn variant(&self) -> &StepVariant {
123 self.variant.get_or_init(|| self.inner.read().variant.clone())
124 }
125
126 pub fn function_calls(&self) -> usize {
128 let calls = &self.inner.read().function_calls;
130 let mut function_calls = calls.len();
131
132 match self.variant() {
136 StepVariant::Statement(Statement::EmitStatement { .. }) => {
137 function_calls = function_calls.saturating_sub(1);
138 }
139 StepVariant::Statements(ref stmts) => {
140 let emit_n = stmts
141 .iter()
142 .filter(|stmt| matches!(stmt, Statement::EmitStatement { .. }))
143 .count();
144 function_calls = function_calls.saturating_sub(emit_n);
145 }
146 _ => {}
147 }
148
149 static BUILT_IN_FUNCTIONS: &[&str] =
151 &["require", "assert", "keccak256", "sha256", "ripemd160", "ecrecover", "type"];
152 let built_in_n = calls
153 .iter()
154 .filter(|call| {
155 if let Expression::Identifier(ref id) = call.expression {
156 BUILT_IN_FUNCTIONS.contains(&id.name.as_str())
157 } else {
158 false
159 }
160 })
161 .count();
162 function_calls = function_calls.saturating_sub(built_in_n);
163
164 *self.function_calls.get_or_init(|| function_calls)
165 }
166
167 pub fn function_entry(&self) -> Option<UFID> {
169 if let StepVariant::FunctionEntry(_) = self.variant() {
170 Some(self.read().ufid)
171 } else {
172 None
173 }
174 }
175
176 pub fn modifier_entry(&self) -> Option<UFID> {
178 if let StepVariant::ModifierEntry(_) = self.variant() {
179 Some(self.read().ufid)
180 } else {
181 None
182 }
183 }
184
185 pub fn contains_return(&self) -> bool {
187 match self.variant() {
188 StepVariant::Statement(Statement::Return(..)) => true,
189 StepVariant::Statements(stmts) => {
190 stmts.iter().any(|s| matches!(s, Statement::Return(..)))
191 }
192 _ => false,
194 }
195 }
196}
197
// Accessors forwarded to the read-locked `Step` via the `delegate` crate.
// NOTE(review): the delegation list is currently empty, so this impl adds no methods.
impl StepRef {
    delegate! {
        to self.inner.read() {
        }
    }
}
204
/// A single step, described by its [`StepVariant`] and source location, plus the
/// function calls and variables recorded for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Step {
    /// Step identifier, allocated with `USID::next()` in `Step::new`.
    pub usid: USID,
    /// Identifier of the owning function (presumably — `UFID` is defined in
    /// `crate::analysis`; confirm).
    pub ufid: UFID,
    /// The source construct this step corresponds to.
    pub variant: StepVariant,
    /// Source location of the step.
    pub src: SourceLocation,
    /// Function-call expressions recorded for this step (empty when created by
    /// `Step::new`).
    pub function_calls: Vec<FunctionCall>,
    /// Variables accessible at this step, supplied by the caller of `Step::new`.
    pub accessible_variables: Vec<VariableRef>,
    /// Variables declared by this step (empty when created by `Step::new`).
    pub declared_variables: Vec<VariableRef>,
    /// Variables updated by this step (empty when created by `Step::new`).
    pub updated_variables: Vec<VariableRef>,
    /// The variable scope associated with this step.
    pub scope: VariableScopeRef,
}
240
241impl Step {
242 pub fn new(
253 ufid: UFID,
254 variant: StepVariant,
255 src: SourceLocation,
256 scope: VariableScopeRef,
257 accessible_variables: Vec<VariableRef>,
258 ) -> Self {
259 let usid = USID::next();
260 Self {
261 usid,
262 ufid,
263 variant,
264 src,
265 function_calls: vec![],
266 accessible_variables,
267 declared_variables: vec![],
268 updated_variables: vec![],
269 scope,
270 }
271 }
272}
273
/// The source construct a [`Step`] corresponds to.
#[derive(Debug, Clone, Serialize, Deserialize)]
// Variants wrap whole AST nodes (e.g. `FunctionDefinition`) by value, so their sizes
// differ widely. NOTE(review): boxing was presumably skipped deliberately; confirm the
// size cost is acceptable.
#[allow(clippy::large_enum_variant)]
pub enum StepVariant {
    /// Entry into a function; carries the function's definition.
    FunctionEntry(FunctionDefinition),
    /// Entry into a modifier; carries the modifier's definition.
    ModifierEntry(ModifierDefinition),
    /// A single statement.
    Statement(Statement),
    /// Several statements grouped into one step.
    Statements(Vec<Statement>),
    /// An `if` statement (the name suggests the step covers its condition).
    IfCondition(IfStatement),
    /// A `for` loop.
    ForLoop(ForStatement),
    /// A `while` loop.
    WhileLoop(WhileStatement),
    /// A `do … while` loop.
    DoWhileLoop(DoWhileStatement),
    /// A `try` statement.
    Try(TryStatement),
}
297
298pub fn sloc_ldiff(a: SourceLocation, b: SourceLocation) -> SourceLocation {
301 assert_eq!(a.index, b.index, "The index of `a` and `b` must be the same");
302 let length = b.start.zip(a.start).map(|(end, start)| end.saturating_sub(start));
303 SourceLocation { start: a.start, length, index: a.index }
304}
305
306pub fn sloc_rdiff(a: SourceLocation, b: SourceLocation) -> SourceLocation {
309 assert_eq!(a.index, b.index, "The index of `a` and `b` must be the same");
310 let start = b.start.zip(b.length).map(|(start, length)| start + length);
311 let length = a
312 .start
313 .zip(a.length)
314 .map(|(start, length)| start + length)
315 .zip(start)
316 .map(|(end, start)| end.saturating_sub(start));
317 SourceLocation { start, length, index: a.index }
318}
319
320pub fn stmt_src(stmt: &Statement) -> SourceLocation {
322 match stmt {
323 Statement::Block(block) => block.src,
324 Statement::ExpressionStatement(expression_statement) => expression_statement.src,
325 Statement::Break(break_stmt) => break_stmt.src,
326 Statement::Continue(continue_stmt) => continue_stmt.src,
327 Statement::DoWhileStatement(do_while_statement) => do_while_statement.src,
328 Statement::EmitStatement(emit_statement) => emit_statement.src,
329 Statement::ForStatement(for_statement) => for_statement.src,
330 Statement::IfStatement(if_statement) => if_statement.src,
331 Statement::InlineAssembly(inline_assembly) => inline_assembly.src,
332 Statement::PlaceholderStatement(placeholder_statement) => placeholder_statement.src,
333 Statement::Return(return_stmt) => return_stmt.src,
334 Statement::RevertStatement(revert_statement) => revert_statement.src,
335 Statement::TryStatement(try_statement) => try_statement.src,
336 Statement::UncheckedBlock(unchecked_block) => unchecked_block.src,
337 Statement::VariableDeclarationStatement(variable_declaration_statement) => {
338 variable_declaration_statement.src
339 }
340 Statement::WhileStatement(while_statement) => while_statement.src,
341 }
342}
343
344pub fn block_or_stmt_src(block_or_stmt: &BlockOrStatement) -> SourceLocation {
346 match block_or_stmt {
347 BlockOrStatement::Block(block) => block.src,
348 BlockOrStatement::Statement(statement) => stmt_src(statement),
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SourceLocation` whose three fields are all `Some`.
    macro_rules! sloc {
        ($start:expr, $length:expr, $index:expr) => {
            SourceLocation { start: Some($start), length: Some($length), index: Some($index) }
        };
    }

    #[test]
    fn test_sloc_ldiff() {
        let a = sloc!(0, 10, 0);
        let b = sloc!(5, 5, 0);
        let c = sloc_ldiff(a, b);
        assert_eq!(c, sloc!(0, 5, 0));

        let a = sloc!(0, 10, 0);
        let b = sloc!(0, 10, 0);
        let c = sloc_ldiff(a, b);
        assert_eq!(c, sloc!(0, 0, 0));

        let a = sloc!(0, 10, 0);
        let b = sloc!(10, 10, 0);
        let c = sloc_ldiff(a, b);
        assert_eq!(c, sloc!(0, 10, 0));

        let a = sloc!(5, 5, 0);
        let b = sloc!(0, 10, 0);
        let c = sloc_ldiff(a, b);
        assert_eq!(c, sloc!(5, 0, 0));
    }

    #[test]
    fn test_sloc_ldiff_propagates_missing_start() {
        // Without `a.start` neither the start nor the length can be known.
        let a = SourceLocation { start: None, length: Some(10), index: Some(0) };
        let c = sloc_ldiff(a, sloc!(5, 5, 0));
        assert_eq!(c, SourceLocation { start: None, length: None, index: Some(0) });

        // Without `b.start` the start is still `a.start`, but the length is unknown.
        let b = SourceLocation { start: None, length: Some(5), index: Some(0) };
        let c = sloc_ldiff(sloc!(0, 10, 0), b);
        assert_eq!(c, SourceLocation { start: Some(0), length: None, index: Some(0) });
    }

    #[test]
    #[should_panic(expected = "The index of `a` and `b` must be the same")]
    fn test_sloc_ldiff_rejects_different_index() {
        let _ = sloc_ldiff(sloc!(0, 10, 0), sloc!(5, 5, 1));
    }

    #[test]
    fn test_sloc_rdiff() {
        let a = sloc!(0, 10, 0);
        let b = sloc!(5, 5, 0);
        let c = sloc_rdiff(a, b);
        assert_eq!(c, sloc!(10, 0, 0));

        let a = sloc!(0, 10, 0);
        let b = sloc!(0, 10, 0);
        let c = sloc_rdiff(a, b);
        assert_eq!(c, sloc!(10, 0, 0));

        let a = sloc!(0, 10, 0);
        let b = sloc!(0, 5, 0);
        let c = sloc_rdiff(a, b);
        assert_eq!(c, sloc!(5, 5, 0));

        let a = sloc!(5, 5, 0);
        let b = sloc!(0, 10, 0);
        let c = sloc_rdiff(a, b);
        assert_eq!(c, sloc!(10, 0, 0));
    }

    #[test]
    fn test_sloc_rdiff_propagates_missing_fields() {
        // `b` without a length has no end, so neither start nor length can be computed.
        let b = SourceLocation { start: Some(0), length: None, index: Some(0) };
        let c = sloc_rdiff(sloc!(0, 10, 0), b);
        assert_eq!(c, SourceLocation { start: None, length: None, index: Some(0) });

        // `a` without a start has no end: the start is `b`'s end, the length is unknown.
        let a = SourceLocation { start: None, length: Some(10), index: Some(0) };
        let c = sloc_rdiff(a, sloc!(0, 5, 0));
        assert_eq!(c, SourceLocation { start: Some(5), length: None, index: Some(0) });
    }

    #[test]
    #[should_panic(expected = "The index of `a` and `b` must be the same")]
    fn test_sloc_rdiff_rejects_different_index() {
        let _ = sloc_rdiff(sloc!(0, 10, 0), sloc!(5, 5, 1));
    }
}