use crate::{
expr_default_impl::ExprDefaultImpl,
fold::*,
utils::{has_unique_elements, ToSnakeCase},
};
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote, ToTokens};
use std::{
collections::HashMap,
fmt::{Display, Formatter, Write},
};
use syn::{fold::Fold, Data, DeriveInput, Fields, Generics, PathArguments, Type, Variant};
/// Insertion-ordered set hashed with `FxHasher`.
///
/// The insertion order is what fixes the order in which per-type trait methods
/// are emitted (`input.types` is iterated when generating them).
type FxIndexSet<T> =
indexmap::set::IndexSet<T, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
/// One variant of the derived enum: its name and the types of its unnamed fields.
struct VariantDescription {
/// The variant identifier.
name: Ident,
/// Field types in declaration order; empty for unit variants.
types: Vec<Type>,
}
impl VariantDescription {
    /// Describes a single enum variant.
    ///
    /// Unit variants get an empty type list; tuple variants keep their field
    /// types in declaration order.
    ///
    /// # Panics
    /// Panics on variants with named fields, which are not supported.
    fn new(variant: &Variant) -> VariantDescription {
        let name = variant.ident.clone();
        let types = match &variant.fields {
            Fields::Unit => Vec::new(),
            Fields::Unnamed(tuple_fields) => tuple_fields
                .unnamed
                .iter()
                .map(|field| field.ty.clone())
                .collect(),
            Fields::Named(_) => panic!("Variants with named fields are not supported"),
        };
        VariantDescription { name, types }
    }
}
/// Everything the derive needs to know about the input enum.
struct EnumDescription {
/// Name of the enum.
name: Ident,
/// Type parameters whose bound list renders as exactly `Clone`.
generic_names: Vec<Ident>,
/// The enum's complete generics (all parameters, as declared).
generics: Generics,
/// The enum's variants, in declaration order.
variants: Vec<VariantDescription>,
/// Every distinct field type, followed by the `Clone` generics, without duplicates.
types: FxIndexSet<Type>,
}
impl EnumDescription {
fn new(input: DeriveInput) -> EnumDescription {
let variants: Vec<VariantDescription> = match &input.data {
Data::Struct(_) => {
panic!("Structs are not supported")
}
Data::Enum(data) => data.variants.iter().map(VariantDescription::new).collect(),
Data::Union(_) => {
panic!("Unions are not supported")
}
};
let generics = input.generics;
let generic_names = generics
.type_params()
.flat_map(|x| {
if x.bounds.to_token_stream().to_string() == "Clone" {
Some(x.ident.clone())
} else {
None
}
})
.collect::<Vec<_>>();
let mut types = FxIndexSet::default();
for variant in &variants {
types.extend(variant.types.clone());
}
for generic_name in &generic_names {
types.insert(ident_to_type(generic_name));
}
EnumDescription {
name: input.ident,
generic_names,
generics,
variants,
types,
}
}
}
/// The kind of trait generated for the enum; its `Display` form prefixes the
/// enum name to form the trait name.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ProducedTrait {
/// Maps every child value, possibly to a different type (`apply`).
Apply,
/// Builds a value, picking the variant through `choose_variant` (`r#gen`).
Gen,
/// Visits every child value by reference (`proc`).
Proc,
}
impl ProducedTrait {
fn to_verb(self) -> &'static str {
match self {
ProducedTrait::Apply => "apply this map to",
ProducedTrait::Gen => "generate",
ProducedTrait::Proc => "process",
}
}
}
impl Display for ProducedTrait {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ProducedTrait::Apply => std::fmt::Display::fmt(&"Apply", f),
ProducedTrait::Gen => std::fmt::Display::fmt(&"Gen", f),
ProducedTrait::Proc => std::fmt::Display::fmt(&"Proc", f),
}
}
}
/// Renders `ty` for use behind a reference in a `Proc` method signature.
///
/// A type whose first path segment is `Vec<..>` is rendered as the slice
/// `[..]` (so the parameter becomes `&[..]`); every other type is rendered
/// unchanged.
///
/// # Panics
/// Panics if the leading `Vec` segment has no angle-bracketed arguments.
fn use_slice_instead_of_vec(ty: &Type) -> TokenStream {
    let Type::Path(type_path) = ty else {
        return ty.to_token_stream();
    };
    let head = &type_path.path.segments[0];
    if head.ident != "Vec" {
        return ty.to_token_stream();
    }
    match &head.arguments {
        PathArguments::AngleBracketed(generic_args) => {
            let element = &generic_args.args;
            quote! {[#element]}
        }
        _ => panic!("Vec should have angle brackets."),
    }
}
/// Assigns each type in `input.types` the name of the trait method handling it:
/// the type's token string converted to snake_case.
///
/// # Panics
/// Panics if two types map to the same method name. The message lists every
/// type/name pair in the (stable) order of `input.types`.
fn fn_name_by_type(input: &EnumDescription) -> HashMap<Type, Ident> {
    let fn_name_by_type: HashMap<Type, Ident> = input
        .types
        .iter()
        .map(|ty| {
            (
                ty.clone(),
                format_ident!("{}", ty.to_token_stream().to_string().to_snake_case()),
            )
        })
        .collect();
    let are_all_fn_names_unique = has_unique_elements(fn_name_by_type.values());
    if !are_all_fn_names_unique {
        // Walk the insertion-ordered `input.types` rather than the HashMap, whose
        // iteration order is randomized per process, so the diagnostic is
        // reproducible from one compilation to the next.
        let mut display = String::new();
        for ty in &input.types {
            let _ = write!(display, "\n{}: {}", ty.to_token_stream(), fn_name_by_type[ty]);
        }
        panic!("All function names are not unique: {}", display);
    }
    fn_name_by_type
}
impl ProducedTrait {
    /// Name of the generated trait: the trait kind followed by the enum name
    /// (for example `ApplyFieldExpr`).
    fn trait_name(self, input: &EnumDescription) -> Ident {
        format_ident!("{self}{}", input.name)
    }

    /// Tokens for `Enum::Variant` with the given field bindings, usable both as
    /// a match pattern and as a constructor expression.
    ///
    /// A variant without fields is written `Enum::Variant {}`. That form is valid
    /// for unit variants and for empty tuple variants alike, in pattern and in
    /// expression position. `Enum::Variant()` is rejected by the compiler for
    /// unit variants.
    fn variant_tokens(input_name: &Ident, variant_name: &Ident, idents: &[Ident]) -> TokenStream {
        if idents.is_empty() {
            quote! {#input_name::#variant_name {}}
        } else {
            quote! {#input_name::#variant_name(#(#idents),*)}
        }
    }

    /// Generates the trait of kind `self` for the enum described by `input`.
    /// For `Apply` and `Proc` it also generates the trait's implementation for
    /// `ClosureWrapper`.
    ///
    /// The trait declares one method per entry of `input.types`, named through
    /// `fn_name_by_type`. It also declares a final method (`apply`, `r#gen` or
    /// `proc`) that matches on the enum and calls those per-type methods on each
    /// field.
    ///
    /// # Panics
    /// Panics if a type is missing from `fn_name_by_type`, or if the enum has
    /// no type parameter bounded by exactly `Clone`.
    fn produce(
        self,
        input: &EnumDescription,
        fn_name_by_type: &HashMap<Type, Ident>,
    ) -> TokenStream {
        let trait_name = self.trait_name(input);
        let input_name = &input.name;
        // Maps every `Clone` generic `T` to `NewT`, its counterpart on the
        // output side of `Apply`.
        let hash_map = input
            .generic_names
            .iter()
            .map(|x| (x.clone(), format_ident!("New{x}")))
            .collect::<HashMap<_, _>>();
        // `Apply` is generic over both the input and the output parameters; the
        // other traits reuse the enum's own generics.
        let trait_generics = if self == ProducedTrait::Apply {
            let mut dup = GenericDuplication::new(&hash_map);
            dup.fold_generics(input.generics.clone())
        } else {
            input.generics.clone()
        };
        // One method per handled type. It gets a default body when
        // `expr_default_impl` provides one; otherwise (and always for `Gen`) it
        // is left to the implementor.
        let functions_to_implement: TokenStream = input
            .types
            .iter()
            .map(|ty| {
                let fn_name = fn_name_by_type.get(ty).unwrap();
                let doc = format!(
                    "How to {} all the elements of type {}.",
                    self.to_verb(),
                    ty.to_token_stream()
                );
                match self {
                    ProducedTrait::Apply => {
                        let local_val = format_ident!("val");
                        let mut replacer = TypeReplacer::new(&hash_map);
                        let output_type = replacer.fold_type(ty.clone());
                        let default_inner = ty.expr_default_impl(
                            &hash_map,
                            fn_name_by_type,
                            false,
                            true,
                            true,
                            local_val.to_token_stream(),
                        );
                        let default_inner = default_inner
                            .map(|inner| quote! {{#inner}})
                            .unwrap_or(quote! {;});
                        quote! {
                            #[doc = #doc]
                            fn #fn_name(&mut self, #local_val: #ty) -> #output_type #default_inner
                        }
                    }
                    ProducedTrait::Gen => {
                        quote! {
                            #[doc = #doc]
                            fn #fn_name(&mut self) -> #ty;
                        }
                    }
                    ProducedTrait::Proc => {
                        let local_val = format_ident!("val");
                        let default_inner = ty.expr_default_impl(
                            &hash_map,
                            fn_name_by_type,
                            false,
                            true,
                            false,
                            local_val.to_token_stream(),
                        );
                        let default_inner = default_inner
                            .map(|inner| quote! {{#inner;}})
                            .unwrap_or(quote! {;});
                        // `Vec<X>` fields are taken as `&[X]`.
                        let ty = use_slice_instead_of_vec(ty);
                        quote! {
                            #[doc = #doc]
                            fn #fn_name(&mut self, val: &#ty) #default_inner
                        }
                    }
                }
            })
            .collect();
        // One match arm per variant, calling the per-type method on every field.
        let inner_match: TokenStream = input
            .variants
            .iter()
            .enumerate()
            .map(|(idx, x)| {
                let variant_name = &x.name;
                // One binding `n0`, `n1`, ... per field of the variant.
                let idents = (0..x.types.len())
                    .map(|i| format_ident!("n{i}"))
                    .collect::<Vec<_>>();
                let variant = Self::variant_tokens(input_name, variant_name, &idents);
                let f_calls: TokenStream = x
                    .types
                    .iter()
                    .zip(&idents)
                    .map(|(ty, local_ident)| {
                        let fn_name = fn_name_by_type.get(ty).unwrap();
                        match self {
                            ProducedTrait::Apply => {
                                quote! {let #local_ident = self.#fn_name(#local_ident);}
                            }
                            ProducedTrait::Gen => quote! {let #local_ident = self.#fn_name();},
                            ProducedTrait::Proc => quote! {self.#fn_name(#local_ident);},
                        }
                    })
                    .collect();
                // `Gen` matches on the index returned by `choose_variant`; the
                // other traits destructure the variant.
                let start = if self == ProducedTrait::Gen {
                    quote! {#idx}
                } else {
                    variant.clone()
                };
                // `Proc` only visits the fields; the other traits rebuild the variant.
                let finish = if self == ProducedTrait::Proc {
                    quote! {}
                } else {
                    variant
                };
                quote! {
                    #start => {
                        #f_calls
                        #finish
                    }
                }
            })
            .collect();
        let input_generics = input.generics.split_for_impl().1;
        let final_function_doc = format!(
            "A function to {} elements of type {}.",
            self.to_verb(),
            input_name,
        );
        let final_function = match self {
            ProducedTrait::Apply => {
                let mut replacer = TypeReplacer::new(&hash_map);
                let replaced = replacer.fold_generics(input.generics.clone());
                let output_generics = replaced.split_for_impl().1;
                quote! {
                    fn apply(&mut self, expr: #input_name #input_generics) -> #input_name #output_generics {
                        match expr {
                            #inner_match
                        }
                    }
                }
            }
            ProducedTrait::Gen => {
                let n_variants = input.variants.len();
                quote! {
                    fn r#gen(&mut self) -> #input_name #input_generics {
                        let n = self.choose_variant(#n_variants);
                        match n {
                            #inner_match
                            _ => {panic!("not enough variants")}
                        }
                    }
                    #[doc="Chooses which variant to generate."]
                    fn choose_variant(&mut self, n_variants: usize) -> usize;
                }
            }
            ProducedTrait::Proc => {
                quote! {
                    fn proc(&mut self, expr: &#input_name #input_generics) {
                        match expr {
                            #inner_match
                        }
                    }
                }
            }
        };
        let doc = format!(
            "A trait to {} elements of type {}.",
            self.to_verb(),
            input_name,
        );
        // Lets a plain closure act as the trait. Every `Clone` generic is unified
        // with the first one, and each per-generic method forwards to the closure.
        let extra_impls = {
            let impl_generics = {
                let mut generic_remover = GenericRemoval::new(&input.generic_names);
                generic_remover.fold_generics(input.generics.clone()).params
            };
            let input_ty = input
                .generic_names
                .first()
                .expect("the enum needs at least one type parameter bounded by exactly `Clone`");
            let replace_hash_map: HashMap<Ident, Ident> = input
                .generic_names
                .iter()
                .map(|x| (x.clone(), input_ty.clone()))
                .collect();
            let replaced_generics =
                TypeReplacer::new(&replace_hash_map).fold_generics(input.generics.clone());
            match self {
                ProducedTrait::Apply => {
                    let output_ty = hash_map.get(input_ty).unwrap();
                    let type_generics =
                        GenericDuplication::new(&hash_map).fold_generics(replaced_generics);
                    let ty_generics = type_generics.split_for_impl().1;
                    let implemented_functions: TokenStream = input
                        .generic_names
                        .iter()
                        .map(|x| {
                            let ty_key = ident_to_type(x);
                            let fn_name = fn_name_by_type.get(&ty_key).unwrap();
                            quote! {
                                fn #fn_name(&mut self, val: #input_ty) -> #output_ty {
                                    (self.func)(val)
                                }
                            }
                        })
                        .collect();
                    quote! {
                        impl <#input_ty: Clone, #output_ty: Clone, FN: FnMut(#input_ty) -> #output_ty, #impl_generics> #trait_name #ty_generics for ClosureWrapper<#input_ty, #output_ty, FN> {
                            #implemented_functions
                        }
                    }
                }
                ProducedTrait::Gen => quote! {},
                ProducedTrait::Proc => {
                    let ty_generics = replaced_generics.split_for_impl().1;
                    let implemented_functions: TokenStream = input
                        .generic_names
                        .iter()
                        .map(|x| {
                            let ty_key = ident_to_type(x);
                            let fn_name = fn_name_by_type.get(&ty_key).unwrap();
                            quote! {
                                fn #fn_name(&mut self, val: &#input_ty) {
                                    (self.func)(val)
                                }
                            }
                        })
                        .collect();
                    quote! {
                        impl <#input_ty: Clone, FN: FnMut(&#input_ty), #impl_generics> #trait_name #ty_generics for ClosureWrapper<&#input_ty, (), FN> {
                            #implemented_functions
                        }
                    }
                }
            }
        };
        quote! {
            #[doc = #doc]
            pub trait #trait_name #trait_generics {
                #functions_to_implement
                #[doc = #final_function_doc]
                #final_function
            }
            #extra_impls
        }
    }
}
fn extra_expr_functions(input: &EnumDescription) -> TokenStream {
fn replace_type_generics(input: &EnumDescription, ident: &Ident) -> TokenStream {
let mut generic_replacer = TypeReplacer::new(
&input
.generic_names
.iter()
.map(|x| (x.clone(), ident.clone()))
.collect(),
);
generic_replacer
.fold_generics(input.generics.clone())
.split_for_impl()
.1
.to_token_stream()
}
let alone_ty = &input.generic_names[0];
let impl_generics = {
let mut generic_remover = GenericRemoval::new(&input.generic_names);
generic_remover.fold_generics(input.generics.clone())
};
let alone_impl_generics = &impl_generics.params;
let alone_type_generics = replace_type_generics(input, alone_ty);
let input_name = &input.name;
let apply_trait_name = ProducedTrait::Apply.trait_name(input);
let proc_trait_name = ProducedTrait::Proc.trait_name(input);
let new_ty = format_ident!("New{}", alone_ty);
let apply_output_type_generics = replace_type_generics(input, &new_ty);
quote! {
impl <#alone_ty: Clone, #alone_impl_generics> #input_name #alone_type_generics {
pub fn apply<#new_ty: Clone>(self, func: impl FnMut(#alone_ty) -> #new_ty) -> #input_name #apply_output_type_generics {
#apply_trait_name::apply(&mut ClosureWrapper::new(func), self)
}
pub fn proc(&self, func: impl FnMut(&#alone_ty)) {
#proc_trait_name::proc(&mut ClosureWrapper::new(func), self)
}
pub fn get_deps(&self) -> Vec<#alone_ty> {
let mut v: Vec<#alone_ty> = Vec::new();
self.proc(|x: &#alone_ty| v.push(x.clone()));
v
}
pub fn n_deps(&self) -> usize {
let mut res = 0usize;
self.proc(|_: &#alone_ty| res += 1);
res
}
}
}
}
fn ident_to_type(ident: &Ident) -> Type {
syn::parse2(ident.to_token_stream()).unwrap()
}
/// Entry point of the derive: expands an enum into its `Apply`, `Gen` and `Proc`
/// traits, the `ClosureWrapper` helper, and the inherent convenience methods.
///
/// If `item` does not parse as a derive input, the parse error is returned as
/// a `compile_error!` token stream instead of panicking.
///
/// # Panics
/// Panics (surfacing as a compile error) on unsupported inputs: structs,
/// unions, named-field variants, colliding method names, or enums without a
/// type parameter bounded by exactly `Clone`.
pub fn expr_macro_derive(item: TokenStream) -> TokenStream {
    let input = match syn::parse2::<DeriveInput>(item) {
        Ok(parsed) => parsed,
        Err(err) => return err.to_compile_error(),
    };
    let input = EnumDescription::new(input);
    let fn_name_by_type = fn_name_by_type(&input);
    let apply_trait = ProducedTrait::Apply.produce(&input, &fn_name_by_type);
    let gen_trait = ProducedTrait::Gen.produce(&input, &fn_name_by_type);
    let proc_trait = ProducedTrait::Proc.produce(&input, &fn_name_by_type);
    let extra_impls = extra_expr_functions(&input);
    // `PhantomData` is fully qualified. The expansion lands in the user's module,
    // where `PhantomData` is not necessarily imported.
    quote! {
        #apply_trait
        #gen_trait
        #proc_trait
        struct ClosureWrapper<I, O, FN: FnMut(I) -> O> {
            func: FN,
            input_marker: ::core::marker::PhantomData<I>,
            output_marker: ::core::marker::PhantomData<O>,
        }
        impl<I, O, FN: FnMut(I) -> O> ClosureWrapper<I, O, FN> {
            fn new(func: FN) -> Self {
                ClosureWrapper {
                    func,
                    input_marker: ::core::marker::PhantomData,
                    output_marker: ::core::marker::PhantomData,
                }
            }
        }
        #extra_impls
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands the derive on a representative enum and checks three things:
    /// the output parses as a Rust file, the generated traits are present, and
    /// the helper items are present. The formatted expansion is also printed
    /// for inspection.
    #[test]
    fn test_expr_macro() {
        let x = quote! {
            pub enum FieldExpr<F: UsedField, T: Clone, C: Clone = T, P: Clone = T> {
                Input(InputId, InputInfo<F>),
                Add(T, T),
                Sub(T, T),
                AddConst(T, F),
                Mul(T, T),
                MulConst(T, F),
                LinComb(Vec<(T, F)>, F),
                Gt(T, T),
                Ge(T, T),
                Rem(T, P),
                Reveal(T),
                Val(F),
                Where(C, T, T),
                Equal(T, T),
                Neg(T),
                Abs(T),
                LogicalRightShift(T, usize),
                Div(T, P),
                Bounds(T, FieldBounds<F>),
                SubCircuit(Vec<T>, ArithmeticCircuit<F>),
                FieldInverse(P),
                SumOfLinearCombinationOfMatrixProducts(Vec<(Vec<Vec<T>>, Vec<Vec<T>>, F)>)
            }
        };
        let y = expr_macro_derive(x);
        let syntax_tree = syn::parse_file(&y.to_string()).unwrap();
        let formatted = prettyplease::unparse(&syntax_tree);
        print!("{}", formatted);
        for expected in [
            "pub trait ApplyFieldExpr",
            "pub trait GenFieldExpr",
            "pub trait ProcFieldExpr",
            "fn choose_variant(",
            "struct ClosureWrapper",
            "pub fn get_deps(&self)",
            "pub fn n_deps(&self)",
        ] {
            assert!(
                formatted.contains(expected),
                "generated code is missing `{expected}`"
            );
        }
    }
}