edb_engine/analysis/
types.rs1use std::{collections::HashMap, sync::Arc};
18
19use alloy_dyn_abi::DynSolType;
20use foundry_compilers::artifacts::{
21 ContractDefinition, EnumDefinition, Expression, StructDefinition, TypeName,
22 UserDefinedValueTypeDefinition,
23};
24use once_cell::sync::OnceCell;
25use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
26use serde::{Deserialize, Serialize};
27
28use crate::analysis::macros::universal_id;
29
// Invokes the crate-local `universal_id!` macro for the `UTID` identifier type;
// `UserDefinedType::new` draws a fresh value from `UTID::next()` for every type.
// NOTE(review): the `=> 0` argument is presumably the initial counter value —
// confirm against the macro definition in `crate::analysis::macros`.
universal_id! {
    UTID => 0
}
34
35#[derive(Debug, Clone)]
37pub struct UserDefinedTypeRef {
38 inner: Arc<RwLock<UserDefinedType>>,
39 utid: OnceCell<UTID>,
41 ast_id: OnceCell<usize>,
43 variant: OnceCell<UserDefinedTypeVariant>,
45}
46
47impl UserDefinedTypeRef {
48 pub fn new(inner: UserDefinedType) -> Self {
50 Self {
51 inner: Arc::new(RwLock::new(inner)),
52 utid: OnceCell::new(),
53 ast_id: OnceCell::new(),
54 variant: OnceCell::new(),
55 }
56 }
57}
58
59impl From<UserDefinedType> for UserDefinedTypeRef {
60 fn from(value: UserDefinedType) -> Self {
61 Self::new(value)
62 }
63}
64
65#[allow(unused)]
66impl UserDefinedTypeRef {
67 pub(crate) fn read(&self) -> RwLockReadGuard<'_, UserDefinedType> {
68 self.inner.read()
69 }
70
71 pub(crate) fn write(&self) -> RwLockWriteGuard<'_, UserDefinedType> {
72 self.inner.write()
73 }
74
75 pub(crate) fn utid(&self) -> UTID {
76 *self.utid.get_or_init(|| self.inner.read().utid)
77 }
78
79 pub(crate) fn ast_id(&self) -> usize {
80 *self.ast_id.get_or_init(|| self.inner.read().variant.ast_id())
81 }
82
83 pub(crate) fn is_typed_address(&self) -> bool {
85 matches!(self.inner.read().variant, UserDefinedTypeVariant::Contract(_))
86 }
87
88 pub(crate) fn variant(&self) -> &UserDefinedTypeVariant {
89 self.variant.get_or_init(|| self.inner.read().variant.clone())
90 }
91}
92
93impl Serialize for UserDefinedTypeRef {
94 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
95 where
96 S: serde::Serializer,
97 {
98 self.inner.read().serialize(serializer)
99 }
100}
101
102impl<'de> Deserialize<'de> for UserDefinedTypeRef {
103 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104 where
105 D: serde::Deserializer<'de>,
106 {
107 let user_defined_type = UserDefinedType::deserialize(deserializer)?;
108 Ok(user_defined_type.into())
109 }
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct UserDefinedType {
115 pub utid: UTID,
117 pub variant: UserDefinedTypeVariant,
119 pub source_id: u32,
121}
122
123impl UserDefinedType {
124 pub fn new(source_id: u32, variant: UserDefinedTypeVariant) -> Self {
126 Self { utid: UTID::next(), variant, source_id }
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132#[allow(clippy::large_enum_variant)]
133pub enum UserDefinedTypeVariant {
134 Struct(StructDefinition),
136 Enum(EnumDefinition),
138 UserDefinedValueType(UserDefinedValueTypeDefinition),
140 Contract(ContractDefinition),
142}
143
144impl UserDefinedTypeVariant {
145 pub fn ast_id(&self) -> usize {
147 match self {
148 Self::Struct(definition) => definition.id,
149 Self::Enum(definition) => definition.id,
150 Self::UserDefinedValueType(definition) => definition.id,
151 Self::Contract(definition) => definition.id,
152 }
153 }
154}
155
156pub fn dyn_sol_type(
165 all_user_defined_types: &HashMap<usize, UserDefinedTypeRef>,
166 type_name: &TypeName,
167) -> Option<DynSolType> {
168 match type_name {
169 TypeName::ArrayTypeName(array_type_name) => {
170 let base = dyn_sol_type(all_user_defined_types, &array_type_name.base_type)?;
171 match array_type_name.length.as_ref() {
172 Some(Expression::Literal(literal)) => {
173 let len = literal.value.as_ref()?;
174 let len = len.parse::<usize>().ok()?;
175 Some(DynSolType::FixedArray(Box::new(base), len))
176 }
177 Some(_) => None,
178 None => Some(DynSolType::Array(Box::new(base))),
179 }
180 }
181 TypeName::ElementaryTypeName(elementary_type_name) => {
182 DynSolType::parse(&elementary_type_name.name).ok()
183 }
184 TypeName::FunctionTypeName(_) => Some(DynSolType::Function),
185 TypeName::Mapping(_) => None,
186 TypeName::UserDefinedTypeName(user_defined_type_name) => {
187 if user_defined_type_name.referenced_declaration < 0 {
188 return None;
189 }
190
191 let ty_def = all_user_defined_types
192 .get(&(user_defined_type_name.referenced_declaration as usize))?;
193
194 match ty_def.variant() {
195 UserDefinedTypeVariant::Struct(definition) => {
196 let mut prop_names = Vec::with_capacity(definition.members.len());
197 let mut prop_types = Vec::with_capacity(definition.members.len());
198 for field in definition.members.iter() {
199 prop_names.push(field.name.clone());
200 prop_types.push(
201 dyn_sol_type(all_user_defined_types, field.type_name.as_ref()?)
202 .unwrap(),
203 );
204 }
205 Some(DynSolType::CustomStruct {
206 name: definition.name.clone(),
207 prop_names,
208 tuple: prop_types,
209 })
210 }
211 UserDefinedTypeVariant::Enum(_) => Some(DynSolType::Uint(8)),
212 UserDefinedTypeVariant::UserDefinedValueType(
213 user_defined_value_type_definition,
214 ) => {
215 let underlying_type = &user_defined_value_type_definition.underlying_type;
216 dyn_sol_type(all_user_defined_types, underlying_type)
217 }
218 UserDefinedTypeVariant::Contract(_) => Some(DynSolType::Address),
219 }
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use alloy_dyn_abi::DynSolType;

    use super::*;
    use crate::analysis::tests::compile_and_analyze;

    /// A state variable of struct type converts to a `CustomStruct` that lists
    /// the member names and member types in declaration order.
    #[test]
    fn test_parse_struct_as_dyn_sol_type() {
        let source = r#"
        contract C {
            struct MyStruct {
                uint256 a;
                uint256 b;
            }
            MyStruct internal myStruct;
        }
        "#;

        let expected = DynSolType::CustomStruct {
            name: "MyStruct".to_string(),
            prop_names: vec!["a".to_string(), "b".to_string()],
            tuple: vec![DynSolType::Uint(256), DynSolType::Uint(256)],
        };

        let (_sources, analysis) = compile_and_analyze(source);
        let state_var = analysis.state_variables.first().unwrap();
        let declared_ty = state_var.type_name().unwrap();
        let converted = dyn_sol_type(&analysis.user_defined_types(), declared_ty).unwrap();

        assert_eq!(converted, expected);
    }
}