use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, GenericParam, Type, TypeParamBound};
pub fn input_expr_derive(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let name = &ast.ident;
if ast.generics.params.is_empty() {
panic!("InputExpr requires at least one type parameter");
}
let first_param_name = match &ast.generics.params[0] {
GenericParam::Type(ty) => &ty.ident,
_ => panic!("InputExpr requires a type parameter"),
};
let has_sp1_air_builder = match &ast.generics.params[0] {
GenericParam::Type(type_param) => type_param.bounds.iter().any(|bound| {
if let TypeParamBound::Trait(trait_bound) = bound {
trait_bound.path.segments.iter().any(|seg| seg.ident == "SP1AirBuilder")
} else {
false
}
}),
_ => false,
};
if !has_sp1_air_builder {
panic!("InputExpr requires the first type parameter to have SP1AirBuilder bound");
}
let mut type_param_replacements = Vec::new();
for (i, param) in ast.generics.params.iter().enumerate() {
if i == 0 {
if let GenericParam::Type(ty) = param {
type_param_replacements
.push((ty.ident.clone(), quote! { sp1_hypercube::ir::ConstraintCompiler }));
}
} else {
if let GenericParam::Type(type_param) = param {
let has_into_expr = type_param.bounds.iter().any(|bound| {
if let TypeParamBound::Trait(trait_bound) = bound {
if trait_bound.path.segments.len() == 1
&& trait_bound.path.segments[0].ident == "Into"
{
if let syn::PathArguments::AngleBracketed(args) =
&trait_bound.path.segments[0].arguments
{
if args.args.len() == 1 {
if let syn::GenericArgument::Type(Type::Path(type_path)) =
&args.args[0]
{
if type_path.path.segments.len() == 2
&& type_path.path.segments[0].ident == *first_param_name
&& type_path.path.segments[1].ident == "Expr"
{
return true;
}
}
}
}
}
}
false
});
if has_into_expr {
type_param_replacements.push((
type_param.ident.clone(),
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr },
));
} else {
panic!(
"Type parameter {} must have bound 'Into<{}::Expr>'",
type_param.ident, first_param_name
);
}
}
}
}
let (field_names, input_exprs, output_exprs): (Vec<_>, Vec<_>, Vec<_>) = match &ast.data {
Data::Struct(data_struct) => {
let items: Vec<_> = data_struct
.fields
.iter()
.filter_map(|field| {
let field_name = field.ident.as_ref()?;
let (input_expr, output_expr) = if let Type::Array(_array_type) = &field.ty {
(
quote! { core::array::from_fn(|_| <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::input_arg(ctx)) },
quote! { core::array::from_fn(|_| <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::output_arg(ctx)) }
)
} else if let Type::Path(type_path) = &field.ty {
if type_path.path.segments.len() == 2 {
let first_seg = &type_path.path.segments[0];
let second_seg = &type_path.path.segments[1];
if first_seg.ident == *first_param_name && (second_seg.ident == "Expr" || second_seg.ident == "Var") {
(
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::input_arg(ctx) },
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::output_arg(ctx) }
)
} else {
(
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::input_from_struct(ctx) },
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::output_from_struct(ctx) }
)
}
} else {
(
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::input_from_struct(ctx) },
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::output_from_struct(ctx) }
)
}
} else {
(
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::input_from_struct(ctx) },
quote! { <sp1_hypercube::ir::ConstraintCompiler as slop_air::AirBuilder>::Expr::output_from_struct(ctx) }
)
};
Some((field_name.clone(), input_expr, output_expr))
})
.collect();
let mut names = Vec::new();
let mut inputs = Vec::new();
let mut outputs = Vec::new();
for (n, i, o) in items {
names.push(n);
inputs.push(i);
outputs.push(o);
}
(names, inputs, outputs)
}
_ => panic!("InputExpr can only be derived for structs"),
};
let input_constructor_call = if field_names.is_empty() {
quote! { #name::new() }
} else {
quote! { #name::new(#(#input_exprs),*) }
};
let output_constructor_call = if field_names.is_empty() {
quote! { #name::new() }
} else {
quote! { #name::new(#(#output_exprs),*) }
};
let concrete_types: Vec<_> = type_param_replacements.iter().map(|(_, ty)| ty).collect();
let expanded = quote! {
impl #name<#(#concrete_types),*> {
fn to_input(&self, ctx: &mut sp1_hypercube::ir::FuncCtx) -> #name<#(#concrete_types),*> {
#input_constructor_call
}
fn to_output(&self, ctx: &mut sp1_hypercube::ir::FuncCtx) -> #name<#(#concrete_types),*> {
#output_constructor_call
}
}
};
TokenStream::from(expanded)
}