// switchboard_evm/bindings/eip712.rs

1use super::switchboard;
2use ethers::core::{
3    abi::ParamType,
4    types::{
5        transaction::eip712::{make_type_hash, EIP712Domain},
6        Address, Bytes, U256,
7    },
8};
9use inflector::Inflector;
10use std::str::FromStr;
11use switchboard_common::EvmTransaction;
12use syn::{
13    parse::Error, spanned::Spanned, Data, Expr, Fields, GenericArgument, Lit, PathArguments,
14    Result as SynResult, Type,
15};
16
17///`Transaction(uint256,uint256,uint256,address,address,bytes)`
18#[derive(
19    Clone,
20    ::ethers::contract::EthAbiType,
21    ::ethers::contract::EthAbiCodec,
22    Default,
23    Debug,
24    PartialEq,
25    Eq,
26    Hash,
27)]
28pub struct Transaction {
29    pub expiration_time_seconds: ::ethers::core::types::U256,
30    pub gas_limit: ::ethers::core::types::U256,
31    pub value: ::ethers::core::types::U256,
32    pub to: ::ethers::core::types::Address,
33    pub from: ::ethers::core::types::Address,
34    pub data: ::ethers::core::types::Bytes,
35}
36
37type Eip712Error = ethers::core::types::transaction::eip712::Eip712Error;
38
39impl Transaction {
40    fn type_hash(&self) -> ::core::result::Result<[u8; 32], Eip712Error> {
41        let input: ::syn::DeriveInput = ::syn::parse_quote! { struct Transaction {
42            pub expiration_time_seconds: ::ethers::core::types::U256,
43            pub gas_limit: ::ethers::core::types::U256,
44            pub value: ::ethers::core::types::U256,
45            pub to: ::ethers::core::types::Address,
46            pub from: ::ethers::core::types::Address,
47            pub data: ::ethers::core::types::Bytes,
48        } };
49        let primary_type = input.clone().ident;
50        let parsed_fields = parse_fields(&input).unwrap();
51        let type_hash = make_type_hash(primary_type.to_string(), &parsed_fields);
52        Ok(type_hash)
53    }
54
55    #[inline]
56    fn domain_separator(
57        &self,
58        domain: ethers::core::types::transaction::eip712::EIP712Domain,
59    ) -> ::core::result::Result<[u8; 32], Eip712Error> {
60        let domain_separator = domain.separator();
61
62        // @NOTE Anoushk - check this
63        let _domain_str = serde_json::to_string(&domain).unwrap();
64        Ok(domain_separator)
65    }
66
67    fn struct_hash(&self) -> ::core::result::Result<[u8; 32], Eip712Error> {
68        let mut items = vec![ethers::core::abi::Token::Uint(
69            ethers::core::types::U256::from(&Self::type_hash(&self)?[..]),
70        )];
71
72        if let ethers::core::abi::Token::Tuple(tokens) =
73            ethers::core::abi::Tokenizable::into_token(::core::clone::Clone::clone(self))
74        {
75            items.reserve(tokens.len());
76            for token in tokens {
77                match &token {
78                    ethers::core::abi::Token::Tuple(_t) => {
79                        // TODO: check for nested Eip712 Type;
80                        // Challenge is determining the type hash
81                        return Err(Eip712Error::NestedEip712StructNotImplemented);
82                    }
83                    _ => {
84                        items.push(
85                            ethers::core::types::transaction::eip712::encode_eip712_type(token),
86                        );
87                    }
88                }
89            }
90        }
91
92        let struct_hash = ethers::core::utils::keccak256(ethers::core::abi::encode(&items));
93
94        Ok(struct_hash)
95    }
96
97    pub fn encode_eip712(
98        &self,
99        domain: ethers::core::types::transaction::eip712::EIP712Domain,
100    ) -> std::result::Result<[u8; 32], Eip712Error> {
101        // encode the digest to be compatible with solidity abi.encodePacked()
102        // See: https://github.com/gakonst/ethers-rs/blob/master/examples/permit_hash.rs#L72
103
104        let domain_separator = self.domain_separator(domain)?;
105        let struct_hash = self.struct_hash()?;
106
107        let digest_input = [&[0x19, 0x01], &domain_separator[..], &struct_hash[..]].concat();
108
109        Ok(ethers::core::utils::keccak256(digest_input))
110    }
111}
112
113pub fn parse_fields(
114    input: &syn::DeriveInput,
115) -> SynResult<Vec<(String, ethers::core::abi::ParamType)>> {
116    let data = match &input.data {
117        Data::Struct(s) => s,
118        Data::Enum(e) => {
119            return Err(Error::new(
120                e.enum_token.span,
121                "Eip712 is not derivable for enums",
122            ))
123        }
124        Data::Union(u) => {
125            return Err(Error::new(
126                u.union_token.span,
127                "Eip712 is not derivable for unions",
128            ))
129        }
130    };
131
132    let named_fields = match &data.fields {
133        Fields::Named(fields) => fields,
134        _ => return Err(Error::new(input.span(), "unnamed fields are not supported")),
135    };
136
137    let mut fields = Vec::with_capacity(named_fields.named.len());
138    for f in named_fields.named.iter() {
139        // strip the raw identifier prefix
140        let name = f.ident.as_ref().unwrap().to_string();
141        let s = name.strip_prefix("r#").unwrap_or(&name);
142        let name = s.to_camel_case();
143
144        let ty = match f
145            .attrs
146            .iter()
147            .find(|a| a.path().segments.iter().any(|s| s.ident == "eip712"))
148        {
149            // Found nested Eip712 Struct
150            // TODO: Implement custom
151            Some(a) => {
152                return Err(Error::new(
153                    a.span(),
154                    "nested Eip712 struct are not yet supported",
155                ))
156            }
157            // Not a nested eip712 struct, return the field param type;
158            None => find_parameter_type(&f.ty)?,
159        };
160
161        fields.push((name, ty));
162    }
163
164    Ok(fields)
165}
166pub fn find_parameter_type(ty: &Type) -> core::result::Result<ParamType, Error> {
167    const ERROR: &str = "Failed to derive proper ABI from array field";
168
169    match ty {
170        Type::Array(arr) => {
171            let ty = find_parameter_type(&arr.elem)?;
172            if let Expr::Lit(ref expr) = arr.len {
173                if let Lit::Int(ref len) = expr.lit {
174                    if let Ok(len) = len.base10_parse::<usize>() {
175                        return match (ty, len) {
176                            (ParamType::Uint(8), 32) => Ok(ParamType::FixedBytes(32)),
177                            (ty, len) => Ok(ParamType::FixedArray(Box::new(ty), len)),
178                        };
179                    }
180                }
181            }
182            Err(Error::new(arr.span(), ERROR))
183        }
184
185        Type::Path(ty) => {
186            // check for `Vec`
187            if let Some(segment) = ty.path.segments.iter().find(|s| s.ident == "Vec") {
188                if let PathArguments::AngleBracketed(ref args) = segment.arguments {
189                    // Vec<T, A?>
190                    debug_assert!(matches!(args.args.len(), 1 | 2));
191                    let ty = args.args.iter().next().unwrap();
192                    if let GenericArgument::Type(ref ty) = ty {
193                        return find_parameter_type(ty)
194                            .map(|kind| ParamType::Array(Box::new(kind)));
195                    }
196                }
197            }
198
199            // match on the last segment of the path
200            ty.path
201                .get_ident()
202                .or_else(|| ty.path.segments.last().map(|s| &s.ident))
203                .and_then(|ident| {
204                    match ident.to_string().as_str() {
205                        // eth types
206                        "Address" => Some(ParamType::Address),
207                        "Bytes" => Some(ParamType::Bytes),
208                        "Uint8" => Some(ParamType::Uint(8)),
209
210                        // core types
211                        "String" => Some(ParamType::String),
212                        "bool" => Some(ParamType::Bool),
213                        // usize / isize, shouldn't happen but use max width
214                        "usize" => Some(ParamType::Uint(64)),
215                        "isize" => Some(ParamType::Int(64)),
216
217                        s => parse_param_type(s),
218                    }
219                })
220                .ok_or_else(|| Error::new(ty.span(), ERROR))
221        }
222
223        Type::Tuple(ty) => ty
224            .elems
225            .iter()
226            .map(find_parameter_type)
227            .collect::<core::result::Result<Vec<_>, _>>()
228            .map(ParamType::Tuple),
229
230        _ => Err(Error::new(ty.span(), ERROR)),
231    }
232}
233pub fn parse_param_type(s: &str) -> Option<ParamType> {
234    match s.chars().next() {
235        Some('H' | 'h') => {
236            let size = s[1..].parse::<usize>().ok()? / 8;
237            Some(ParamType::FixedBytes(size))
238        }
239
240        Some(c @ 'U' | c @ 'I' | c @ 'u' | c @ 'i') => {
241            let size = s[1..].parse::<usize>().ok()?;
242            if matches!(c, 'U' | 'u') {
243                Some(ParamType::Uint(size))
244            } else {
245                Some(ParamType::Int(size))
246            }
247        }
248        _ => None,
249    }
250}
251
252impl From<&EvmTransaction> for Transaction {
253    fn from(tx: &EvmTransaction) -> Self {
254        Transaction {
255            expiration_time_seconds: U256::from(tx.expiration_time_seconds),
256            gas_limit: U256::from_str_radix(&tx.gas_limit, 10).unwrap(),
257            value: U256::from_str_radix(&tx.value, 10).unwrap(),
258            to: Address::from_slice(&tx.to),
259            from: Address::from_slice(&tx.from),
260            data: tx.data.clone().into(),
261        }
262    }
263}
264
265impl From<&switchboard::Transaction> for Transaction {
266    fn from(tx: &switchboard::Transaction) -> Self {
267        Transaction {
268            expiration_time_seconds: tx.expiration_time_seconds,
269            gas_limit: tx.gas_limit,
270            value: tx.value,
271            to: tx.to,
272            from: tx.from,
273            data: tx.data.clone(),
274        }
275    }
276}
277
278pub fn get_transaction_hash(
279    name: String,
280    version: String,
281    chain_id: u64,
282    verifying_contract: Address,
283    transaction: switchboard::Transaction,
284) -> std::result::Result<[u8; 32], Eip712Error> {
285    let tx = Transaction::from(&transaction);
286    let domain = EIP712Domain {
287        name: Some(name.into()),
288        version: Some(version.into()),
289        chain_id: Some(chain_id.into()),
290        verifying_contract: Some(verifying_contract),
291        salt: None,
292    };
293
294    // encode hash
295    tx.encode_eip712(domain.clone())
296}