use crate::InternalStructs;
use ethers_core::abi::{
struct_def::{FieldType, StructFieldType},
ParamType, SolStruct,
};
use eyre::Result;
use inflector::Inflector;
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use std::path::{Path, PathBuf};
/// Builds an [`Ident`] from `name`, spanned at the macro call site.
pub(crate) fn ident(name: &str) -> Ident {
    let span = Span::call_site();
    Ident::new(name, span)
}
/// Turns `name` into an [`Ident`]. If `syn` will not parse `name` as an identifier (for example
/// the keyword `self`), a trailing `_` is appended instead (`self` -> `self_`).
pub(crate) fn safe_ident(name: &str) -> Ident {
    match syn::parse_str::<Ident>(name) {
        Ok(parsed) => parsed,
        Err(_) => ident(&format!("{name}_")),
    }
}
/// Converts `ident` to snake_case. If the result starts with a digit, it is prefixed with `_`.
pub(crate) fn safe_snake_case(ident: &str) -> String {
    let snake = ident.to_snake_case();
    safe_identifier_name(snake)
}
/// Converts `ident` to PascalCase. If the result starts with a digit, it is prefixed with `_`.
pub(crate) fn safe_pascal_case(ident: &str) -> String {
    let pascal = ident.to_pascal_case();
    safe_identifier_name(pascal)
}
/// Prefixes `name` with `_` when its first character is numeric; otherwise returns it unchanged.
///
/// An empty string is returned as-is.
pub(crate) fn safe_identifier_name(name: String) -> String {
    let first = name.chars().next();
    match first {
        Some(c) if c.is_numeric() => format!("_{name}"),
        _ => name,
    }
}
/// Produces a snake_case module name for `name` that is usable as a Rust identifier.
/// Keywords get a trailing `_` and leading digits get a `_` prefix.
pub(crate) fn safe_module_name(name: &str) -> String {
    let snake = safe_snake_case(name);
    let module_ident = safe_ident(&snake);
    module_ident.to_string()
}
/// Builds a snake_case [`Ident`] from `name` and keeps any leading or trailing underscores that
/// the case conversion removed.
pub(crate) fn safe_snake_case_ident(name: &str) -> Ident {
    let converted = preserve_underscore_delim(&name.to_snake_case(), name);
    ident(&converted)
}
/// Builds a PascalCase [`Ident`] from `name` and keeps any leading or trailing underscores that
/// the case conversion removed.
pub(crate) fn safe_pascal_case_ident(name: &str) -> Ident {
    let converted = preserve_underscore_delim(&name.to_pascal_case(), name);
    ident(&converted)
}
/// Wraps `ident` in the same leading and trailing underscore runs that `original` has.
///
/// Both runs are measured independently on `original`. An all-underscore `original` therefore
/// contributes its underscores twice (once as the prefix, once as the suffix).
pub(crate) fn preserve_underscore_delim(ident: &str, original: &str) -> String {
    // `_` is a single byte, so byte-length differences equal underscore counts.
    let leading = original.len() - original.trim_start_matches('_').len();
    let trailing = original.len() - original.trim_end_matches('_').len();
    format!("{}{ident}{}", "_".repeat(leading), "_".repeat(trailing))
}
/// Expands a function input's name into an identifier token.
///
/// An empty name becomes the positional name `p{index}`. Any other name is snake_cased and passed
/// through [`safe_ident`], so keywords such as `self` become `self_`.
pub(crate) fn expand_input_name(index: usize, name: &str) -> TokenStream {
    let snake = if name.is_empty() { format!("p{index}") } else { name.to_snake_case() };
    let input_ident = safe_ident(&snake);
    quote!(#input_ident)
}
#[cfg(all(feature = "online", not(target_arch = "wasm32")))]
/// Performs a blocking HTTP GET on `url` and returns the response body as text.
///
/// # Errors
///
/// Propagates any request or body-decoding error from `reqwest`.
pub(crate) fn http_get(url: impl reqwest::IntoUrl) -> Result<String> {
    let response = reqwest::blocking::get(url)?;
    let body = response.text()?;
    Ok(body)
}
/// Expands `$NAME` environment-variable references in `raw` and returns the result as a path.
///
/// A variable name is the longest run after a `$` that starts with `_` or an ASCII letter and
/// continues with `_`, ASCII letters or ASCII digits. Each name is replaced by the value returned
/// from [`std::env::var`]. Text outside variable references is copied verbatim.
///
/// # Errors
///
/// Fails if a `$` is not followed by a valid variable name. Also fails if a referenced variable
/// is unset or its value is not valid unicode; that error names the offending variable.
pub(crate) fn resolve_path(raw: &str) -> Result<PathBuf> {
    let mut unprocessed = raw;
    let mut resolved = String::with_capacity(raw.len());
    while let Some(dollar_sign) = unprocessed.find('$') {
        let (head, tail) = unprocessed.split_at(dollar_sign);
        resolved.push_str(head);
        // `$` is a single byte, so `tail[1..]` always starts on a char boundary.
        match parse_identifier(&tail[1..]) {
            Some((variable, rest)) => {
                // A bare `VarError` does not say which variable failed, so add that context.
                let value = std::env::var(variable).map_err(|err| {
                    eyre::eyre!(
                        "Unable to resolve environment variable \"{variable}\" in \"{raw}\": {err}"
                    )
                })?;
                resolved.push_str(&value);
                unprocessed = rest;
            }
            None => {
                eyre::bail!("Unable to parse a variable from \"{tail}\"")
            }
        }
    }
    resolved.push_str(unprocessed);
    Ok(PathBuf::from(resolved))
}
/// Splits a leading variable name off `text`, returning `(name, rest)`.
///
/// The name starts with `_` or an ASCII letter; ASCII digits are accepted only after the first
/// character. Returns `None` if `text` does not start with such a name.
fn parse_identifier(text: &str) -> Option<(&str, &str)> {
    let mut at_start = true;
    let (name, rest) = take_while(text, |c| {
        let first = std::mem::replace(&mut at_start, false);
        c == '_' || c.is_ascii_alphabetic() || (!first && c.is_ascii_digit())
    });
    if name.is_empty() {
        return None;
    }
    Some((name, rest))
}
/// Splits `s` just before the first character that fails `predicate`.
///
/// `predicate` is called on each character in order and stops at the first `false`. If every
/// character passes, the whole string is returned as the head and the tail is empty.
fn take_while(s: &str, mut predicate: impl FnMut(char) -> bool) -> (&str, &str) {
    let split = s
        .char_indices()
        .find(|&(_, c)| !predicate(c))
        .map_or(s.len(), |(byte_idx, _)| byte_idx);
    s.split_at(split)
}
/// Recursively collects every regular file under `root` whose extension is exactly `json`.
///
/// Directory entries that `walkdir` fails to read are skipped silently.
pub(crate) fn json_files(root: impl AsRef<Path>) -> Vec<PathBuf> {
    let mut found = Vec::new();
    for entry in walkdir::WalkDir::new(root).into_iter().flatten() {
        let path = entry.path();
        let is_json = path.extension().map_or(false, |ext| ext == "json");
        if entry.file_type().is_file() && is_json {
            found.push(path.to_path_buf());
        }
    }
    found
}
/// Appends derive names to `stream` for a type built from `params`.
///
/// `Default` is added only if `derive_default` is set and every param supports it. The group
/// `Debug, PartialEq, Eq, Hash` is added only if `derive_others` is set and every param supports
/// those traits.
pub(crate) fn derive_builtin_traits<'a>(
    params: impl IntoIterator<Item = &'a ParamType>,
    stream: &mut TokenStream,
    derive_default: bool,
    derive_others: bool,
) {
    let (default_ok, others_ok) = params.into_iter().fold(
        (derive_default, derive_others),
        |(def, others), param| {
            (def && can_derive_default(param), others && can_derive_builtin_traits(param))
        },
    );
    extend_derives(stream, default_ok, others_ok);
}
/// Appends the builtin derives that a generated Solidity struct can support to `stream`.
///
/// If no field is itself a struct, the decision uses `params` alone. Otherwise the nested
/// structs are looked up in `structs` and checked recursively.
pub(crate) fn derive_builtin_traits_struct(
    structs: &InternalStructs,
    sol_struct: &SolStruct,
    params: &[ParamType],
    stream: &mut TokenStream,
) {
    let has_struct_field = sol_struct.fields().iter().any(|field| field.ty.is_struct());
    if !has_struct_field {
        derive_builtin_traits(params, stream, true, true);
        return;
    }
    let (mut def, mut others) = (true, true);
    _derive_builtin_traits_struct(structs, sol_struct, params, &mut def, &mut others);
    extend_derives(stream, def, others);
}
/// Recursive worker for `derive_builtin_traits_struct`.
///
/// Walks the fields of `sol_struct`, each paired with its ABI type from `params`. Sets `*def` to
/// `false` when a field cannot derive `Default`, and `*others` to `false` when a field cannot
/// derive `Debug`/`PartialEq`/`Eq`/`Hash`. Neither flag is ever set back to `true`.
///
/// Struct-typed fields are looked up in `structs` by identifier and checked recursively. A field
/// whose struct is not found there is skipped.
///
/// # Panics
///
/// Panics (`unreachable!`) on mapping fields. In debug builds it also asserts that `params` has
/// one entry per field.
fn _derive_builtin_traits_struct(
    structs: &InternalStructs,
    sol_struct: &SolStruct,
    params: &[ParamType],
    def: &mut bool,
    others: &mut bool,
) {
    let fields = sol_struct.fields();
    debug_assert_eq!(fields.len(), params.len());
    for (field, ty) in fields.iter().zip(params) {
        match &field.ty {
            FieldType::Struct(s_ty) => {
                // Every directly nested fixed-size layer must respect the `Default` length limit,
                // not just the outermost one: `S[40][2]` is `[[S; 40]; 2]`, which cannot derive
                // `Default` even though the outer length is 2. A dynamic-array layer ends the
                // walk. This matches how a top-level dynamic array never restricted `Default`
                // (presumably it is generated as `Vec<_>`, which is `Default` for any element).
                let mut layer = s_ty;
                while let StructFieldType::FixedArray(inner, len) = layer {
                    *def &= *len <= MAX_SUPPORTED_ARRAY_LEN;
                    layer = inner;
                }
                let id = s_ty.identifier();
                if let Some(recursed_struct) = structs.structs.get(&id) {
                    let recursed_params = get_struct_params(s_ty, ty);
                    _derive_builtin_traits_struct(
                        structs,
                        recursed_struct,
                        recursed_params,
                        def,
                        others,
                    );
                }
            }
            FieldType::Elementary(ty1) => {
                debug_assert_eq!(ty, ty1);
                *def = *def && can_derive_default(ty);
                *others = *others && can_derive_builtin_traits(ty);
            }
            FieldType::Mapping(_) => unreachable!(),
        }
    }
}
/// Returns the tuple member types behind a struct-typed field's ABI type.
///
/// Array and fixed-array layers on both sides are peeled off together until a `Tuple` is reached.
///
/// # Panics
///
/// Panics (`unreachable!`) if the struct field shape and the ABI type do not line up.
fn get_struct_params<'a>(s_ty: &StructFieldType, ty: &'a ParamType) -> &'a [ParamType] {
    if let ParamType::Tuple(params) = ty {
        return params;
    }
    match (s_ty, ty) {
        (
            StructFieldType::Array(inner) | StructFieldType::FixedArray(inner, _),
            ParamType::Array(elem) | ParamType::FixedArray(elem, _),
        ) => get_struct_params(inner, elem),
        _ => unreachable!("Unhandled struct field: {s_ty:?} | {ty:?}"),
    }
}
/// Appends derive names to `stream`: `Default,` when `def` is set, then
/// `Debug, PartialEq, Eq, Hash` when `others` is set.
fn extend_derives(stream: &mut TokenStream, def: bool, others: bool) {
    let groups = [(def, quote!(Default,)), (others, quote!(Debug, PartialEq, Eq, Hash))];
    for (wanted, tokens) in groups {
        if wanted {
            stream.extend(tokens);
        }
    }
}
/// Largest fixed-size array / `bytesN` length for which `can_derive_default` (and the struct
/// checks) still allow deriving `Default`. Presumably this mirrors std's `Default` impls for
/// `[T; N]`, which stop at N = 32.
const MAX_SUPPORTED_ARRAY_LEN: usize = 32;
/// Largest tuple arity for which `can_derive_default` and `can_derive_builtin_traits` still allow
/// builtin derives. Presumably this mirrors std's trait impls for tuples, which stop at 12
/// elements.
const MAX_SUPPORTED_TUPLE_LEN: usize = 12;
/// Reports whether a value of the Rust type generated for `param` can derive `Default`.
///
/// Fixed-size arrays and `bytesN` must not exceed `MAX_SUPPORTED_ARRAY_LEN`. Tuples must not
/// exceed `MAX_SUPPORTED_TUPLE_LEN` members. Element and member types are checked recursively.
fn can_derive_default(param: &ParamType) -> bool {
    match param {
        ParamType::Array(inner) => can_derive_default(inner),
        ParamType::FixedBytes(len) => *len <= MAX_SUPPORTED_ARRAY_LEN,
        ParamType::FixedArray(inner, len) => {
            *len <= MAX_SUPPORTED_ARRAY_LEN && can_derive_default(inner)
        }
        ParamType::Tuple(members) => {
            members.len() <= MAX_SUPPORTED_TUPLE_LEN && members.iter().all(can_derive_default)
        }
        _ => true,
    }
}
/// Reports whether the Rust type generated for `param` can derive `Debug`, `PartialEq`, `Eq`
/// and `Hash`.
///
/// Array lengths are not limited here. Tuples must not exceed `MAX_SUPPORTED_TUPLE_LEN` members,
/// and element/member types are checked recursively.
fn can_derive_builtin_traits(param: &ParamType) -> bool {
    match param {
        ParamType::Array(inner) | ParamType::FixedArray(inner, _) => {
            can_derive_builtin_traits(inner)
        }
        ParamType::Tuple(members) => {
            members.len() <= MAX_SUPPORTED_TUPLE_LEN
                && members.iter().all(can_derive_builtin_traits)
        }
        _ => true,
    }
}
/// Renders an ABI signature such as `` `transfer(address,uint256)` ``, wrapped in backticks.
pub(crate) fn abi_signature<'a, N, T>(name: N, types: T) -> String
where
    N: std::fmt::Display,
    T: IntoIterator<Item = &'a ParamType>,
{
    format!("`{name}({})`", abi_signature_types(types))
}
/// Joins the `Display` forms of `types` with `,`, with no surrounding whitespace.
pub(crate) fn abi_signature_types<'a, T: IntoIterator<Item = &'a ParamType>>(types: T) -> String {
    let mut joined = String::new();
    for (position, ty) in types.into_iter().enumerate() {
        if position > 0 {
            joined.push(',');
        }
        joined.push_str(&ty.to_string());
    }
    joined
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_detect_derives() {
        let param = ParamType::FixedArray(Box::new(ParamType::Uint(256)), 32);
        assert!(can_derive_default(&param));
        assert!(can_derive_builtin_traits(&param));
        let param = ParamType::FixedArray(Box::new(ParamType::Uint(256)), 33);
        assert!(!can_derive_default(&param));
        assert!(can_derive_builtin_traits(&param));
        let param = ParamType::Tuple(vec![ParamType::Uint(256); 12]);
        assert!(can_derive_default(&param));
        assert!(can_derive_builtin_traits(&param));
        let param = ParamType::Tuple(vec![ParamType::Uint(256); 13]);
        assert!(!can_derive_default(&param));
        assert!(!can_derive_builtin_traits(&param));
    }

    #[test]
    fn can_resolve_path() {
        let raw = "./$ENV_VAR";
        std::env::set_var("ENV_VAR", "file.txt");
        let resolved = resolve_path(raw).unwrap();
        assert_eq!(resolved.to_str().unwrap(), "./file.txt");
    }

    #[test]
    fn can_resolve_path_with_multiple_vars() {
        std::env::set_var("ABIGEN_UTIL_TEST_DIR", "contracts");
        std::env::set_var("ABIGEN_UTIL_TEST_FILE", "A.json");
        let resolved = resolve_path("$ABIGEN_UTIL_TEST_DIR/$ABIGEN_UTIL_TEST_FILE").unwrap();
        assert_eq!(resolved.to_str().unwrap(), "contracts/A.json");
    }

    #[test]
    fn resolve_path_rejects_bad_references() {
        // A dangling `$` and a name starting with a digit are not variable references.
        assert!(resolve_path("./$").is_err());
        assert!(resolve_path("./$1abc").is_err());
        std::env::remove_var("ABIGEN_UTIL_TEST_SURELY_UNSET");
        assert!(resolve_path("./$ABIGEN_UTIL_TEST_SURELY_UNSET").is_err());
    }

    #[test]
    fn can_parse_identifier() {
        assert_eq!(parse_identifier("FOO_1/bar"), Some(("FOO_1", "/bar")));
        assert_eq!(parse_identifier("_9"), Some(("_9", "")));
        assert_eq!(parse_identifier("1abc"), None);
        assert_eq!(parse_identifier(""), None);
    }

    #[test]
    fn take_while_splits_on_char_boundary() {
        assert_eq!(take_while("ééx", |c| c == 'é'), ("éé", "x"));
        assert_eq!(take_while("abc", |_| true), ("abc", ""));
    }

    #[test]
    fn preserves_underscore_delimiters() {
        assert_eq!(preserve_underscore_delim("foo_bar", "_fooBar__"), "_foo_bar__");
        assert_eq!(preserve_underscore_delim("foo", "foo"), "foo");
    }

    #[test]
    fn input_name_to_ident_empty() {
        assert_quote!(expand_input_name(0, ""), { p0 });
    }

    #[test]
    fn input_name_to_ident_keyword() {
        assert_quote!(expand_input_name(0, "self"), { self_ });
    }

    #[test]
    fn input_name_to_ident_snake_case() {
        assert_quote!(expand_input_name(0, "CamelCase1"), { camel_case_1 });
    }

    #[test]
    fn test_safe_module_name() {
        assert_eq!(safe_module_name("Valid"), "valid");
        assert_eq!(safe_module_name("Enum"), "enum_");
        assert_eq!(safe_module_name("Mod"), "mod_");
        assert_eq!(safe_module_name("2Two"), "_2_two");
    }
}