use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, GenericParam, Ident, TypeParamBound};
pub fn input_params_derive(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let name = &ast.ident;
let (fields, field_entries) = match &ast.data {
Data::Struct(data_struct) => data_struct
.fields
.iter()
.filter_map(|field| {
let field_name = field.ident.as_ref()?;
let field_name_str = field_name.to_string();
let attribute = parse_picus_attributes(&field.attrs);
Some((
field.clone(),
quote! {
(#field_name_str.to_string(), #attribute, self.#field_name.into())
},
))
})
.unzip::<_, _, Vec<_>, Vec<_>>(),
_ => panic!("InputParams can only be derived for structs"),
};
let field_names = fields
.iter()
.map(|field| field.ident.clone().expect("Field should be named."))
.collect::<Vec<_>>();
let first_param_name = match ast.generics.params.first() {
Some(GenericParam::Type(ty)) => Some(&ty.ident),
_ => None,
};
let has_sp1_air_builder = ast.generics.params.first().is_some_and(|param| {
if let GenericParam::Type(type_param) = param {
type_param.bounds.iter().any(|bound| {
if let TypeParamBound::Trait(trait_bound) = bound {
trait_bound.path.segments.iter().any(|seg| seg.ident == "SP1AirBuilder")
} else {
false
}
})
} else {
false
}
});
let expanded = if has_sp1_air_builder {
let num_params = ast.generics.params.len();
let first_param_name = first_param_name.expect("First type parameter should be named.");
let field_type_params = fields
.iter()
.map(|field| {
let name = field.ident.as_ref().expect("Field should be named.").clone();
let ty_of = &field.ty;
quote! { #name: #ty_of }
})
.collect::<Vec<_>>();
if num_params == 1 {
quote! {
impl<#first_param_name: SP1AirBuilder> #name<#first_param_name> {
#[allow(clippy::too_many_arguments)]
pub const fn new(#(#field_type_params),*) -> Self {
Self {
#(#field_names),*
}
}
}
impl #name<sp1_hypercube::ir::ConstraintCompiler> {
fn params_vec(
self,
) -> Vec<(
String,
sp1_hypercube::ir::Attribute,
sp1_hypercube::ir::Shape<
<sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr,
<sp1_hypercube::ir::ConstraintCompiler as slop_air::ExtensionBuilder>::ExprEF,
>,
)> {
vec![
#(#field_entries,)*
]
}
}
}
} else {
let remaining_params_with_constraint_compiler = replace_bounds(
ast.generics.params.iter().skip(1),
first_param_name.clone(),
"< sp1_hypercube :: ir :: ConstraintCompiler as slop_air :: AirBuilder >",
);
let remaining_params = ast.generics.params.iter().skip(1);
let type_args = ast.generics.params.iter().skip(1).filter_map(|param| {
if let GenericParam::Type(type_param) = param {
let ident = &type_param.ident;
Some(quote! { #ident })
} else {
None
}
});
let type_args_clone = type_args.clone();
quote! {
impl<#first_param_name: SP1AirBuilder, #(#remaining_params),*> #name<#first_param_name, #(#type_args_clone),*> {
#[allow(clippy::too_many_arguments)]
pub const fn new(#(#field_type_params),*) -> Self {
Self {
#(#field_names),*
}
}
}
impl<#(#remaining_params_with_constraint_compiler),*> #name<sp1_hypercube::ir::ConstraintCompiler, #(#type_args),*> {
fn params_vec(
self,
) -> Vec<(
String,
sp1_hypercube::ir::Attribute,
sp1_hypercube::ir::Shape<
<sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr,
<sp1_hypercube::ir::ConstraintCompiler as slop_air::ExtensionBuilder>::ExprEF,
>,
)> {
vec![
#(#field_entries,)*
]
}
}
}
}
} else {
panic!("InputParams requires the first type parameter to have SP1AirBuilder bound");
};
TokenStream::from(expanded)
}
fn replace_bounds<'a, I>(bounds: I, target: Ident, replacement: &'a str) -> Vec<TokenStream2>
where
I: Iterator<Item = &'a GenericParam>,
{
bounds
.map(move |bound| {
if let GenericParam::Type(type_param) = bound {
let ident = &type_param.ident;
let bounds = &type_param.bounds;
let bounds_str: String = quote! { #bounds }.to_string();
let target_pattern = format!("{target}");
let new_bounds_str = bounds_str.replace(&target_pattern, replacement);
let new_bounds =
syn::parse_str::<syn::TypeParam>(&format!("{ident}: {new_bounds_str}"))
.unwrap_or_else(|_| type_param.clone());
quote! { #new_bounds }
} else {
quote! { #bound }
}
})
.collect()
}
/// Builds the `sp1_hypercube::ir::Attribute` expression for a field from its attributes.
///
/// Walks `attrs` in order and looks inside every `#[picus(...)]` list for a bare `input` or
/// `output` path. The first one found selects `PicusArg::Input` or `PicusArg::Output`.
/// Other list entries, non-list `picus` attributes and attributes that fail to parse are
/// skipped. When nothing matches, `Attribute::default()` is emitted.
fn parse_picus_attributes(attrs: &[Attribute]) -> TokenStream2 {
    let picus_arg = attrs
        .iter()
        .filter(|attr| attr.path.is_ident("picus"))
        .filter_map(|attr| match attr.parse_meta() {
            Ok(syn::Meta::List(meta_list)) => Some(meta_list.nested),
            _ => None,
        })
        .flatten()
        .find_map(|nested| match &nested {
            syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("input") => {
                Some(quote! { sp1_hypercube::ir::PicusArg::Input })
            }
            syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("output") => {
                Some(quote! { sp1_hypercube::ir::PicusArg::Output })
            }
            _ => None,
        });

    match picus_arg {
        Some(arg) => quote! {
            sp1_hypercube::ir::Attribute {
                picus: #arg,
            }
        },
        None => quote! { sp1_hypercube::ir::Attribute::default() },
    }
}