n_functor/
lib.rs

1use std::collections::BTreeMap;
2use proc_macro::TokenStream;
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::{quote, quote_spanned, ToTokens};
5use syn::parse::Parser;
6use syn::punctuated::Punctuated;
7use syn::spanned::Spanned;
8use syn::{
9    parse_macro_input, Attribute, Field,
10    Fields, GenericParam, Ident, Item, ItemEnum, ItemStruct, Meta, MetaList,
11    MetaNameValue, PathArguments, Token, Type, TypeParam, Variant,
12};
13
14/// Generate a `map` function for a given type that maps across all its type parameters.
15/// i.e.
16/// ```
17/// #[derive(...)]
18/// // optional: setting a name for the type parameters, doesn't affect the structure
19/// // of the data in any way, just the variable names.
20/// #[derive_n_functor(B = second_type_param)]
21/// // We can also choose a different map name, the default being `map`.
22/// // This will recurse down to child elements without a custom `map_with` declaration.
23/// // #[derive_n_functor(map_name = different_map)]
24/// struct Data<A, B> {
25///     a: A,
26///     // The map_with argument is an arbitrary expression.
27///     #[map_with(Option::map)]
28///     b: Option<B>
29/// }
30/// ```
31///
32/// Will generate a mapping function of the form: `Data::map(self, map_a: impl Fn(A) -> A2, map_second_type_param: impl B -> B2) -> Data<A2, B2>`.
33///
34/// See examples and use `cargo-expand` to see how different code generates.
35///
36/// Currently works with enums and structs. 
37/// 
38/// Caveats:
39/// - Does not work with data structures that have lifetimes or constants in them at this time. 
40/// - Does not currently work well with i.e. tuples where one of the types within is a type parameter. if you need to deal with this, write an external function that applies the mappings (see examples.)
41#[proc_macro_attribute]
42pub fn derive_n_functor(args: TokenStream, item: TokenStream) -> TokenStream {
43    let _args: TokenStream2 = args.clone().into();
44    let _item: TokenStream2 = item.clone().into();
45    let args = Args::from_token_stream(args);
46    let mut input = parse_macro_input!(item as Item);
47    let output = match &mut input {
48        Item::Enum(_enum) => AbstractFunctorFactory::from_item_enum(args, _enum),
49        Item::Struct(_struct) => AbstractFunctorFactory::from_item_struct(args, _struct),
50        _ => {
51            quote_spanned! {_args.span() => compile_error!("Could not derive n-functor for this, it is neither an enum or struct.")}
52        }
53    };
54    quote! {
55        #input
56        #output
57    }
58    .into()
59}
60
61struct Args {
62    pub parameter_names: BTreeMap<Ident, Ident>,
63    pub mapping_name: String,
64    // Alternative functions for mapping for specific fields.
65    // pub alt_functions: BTreeMap<Ident, TokenStream>,
66}
67
68impl Args {
69    fn from_token_stream(stream: TokenStream) -> Self {
70        let parsed_attrs: Punctuated<MetaNameValue, Token![,]> =
71            Parser::parse2(Punctuated::parse_terminated, stream.into()).unwrap();
72        Args::from_iter(parsed_attrs.into_iter())
73    }
74
75    fn from_iter(input: impl Iterator<Item = MetaNameValue>) -> Self {
76        let search_for_mapping_token = Ident::new("map_name", Span::call_site());
77        let mut mapping_name = "map".to_string();
78        let parameter_names = input
79            .filter_map(|name_val| {
80                if name_val.path.segments.last().unwrap().ident == search_for_mapping_token {
81                    // found the map renaming arg so skip this one after renaming mapping_name
82                    if let syn::Expr::Path(path) = name_val.value {
83                        mapping_name = path.path.segments.last()?.ident.to_string();
84                    }
85                    // return none as we've consumed this input.
86                    return None
87                }
88                // continue to processing 
89                let ty_ident = &name_val.path.segments.last()?.ident;
90                let rename_ident = &match &name_val.value {
91                    syn::Expr::Path(path) => path.path.segments.last(),
92                    _ => None,
93                }?
94                .ident;
95                Some((ty_ident.clone(), rename_ident.clone()))
96            })
97            .collect();
98        Self { parameter_names, mapping_name }
99    }
100
101    fn get_suffix_for(&self, ident: &Ident) -> Ident {
102        self.parameter_names
103            .get(ident)
104            .cloned()
105            .unwrap_or_else(|| Ident::new(&format!("{ident}"), Span::call_site()))
106    }
107
108    fn get_whole_map_name(&self, ident: &Ident) -> Ident {
109        let suffix = self.get_suffix_for(ident);
110        Ident::new(&format!("map_{suffix}"), Span::call_site())
111    }
112
113    fn get_map_all_name(&self) -> Ident {
114        Ident::new(&self.mapping_name, Span::call_site())
115    }
116}
117
118enum FieldMapping {
119    Trivial(Ident),
120    SubMapForArgs(Vec<Ident>),
121}
122
123type FieldNameMapping = Option<Vec<Ident>>;
124
125struct AbstractFunctorFactory {
126    pub args: Args,
127    // this is a vec for reasons of preserving order of type parameters.
128    pub type_maps_to_type: Vec<(Ident, Ident)>,
129    pub type_name: Ident,
130}
131
132impl AbstractFunctorFactory {
133    fn from_generics<'a>(
134        args: Args,
135        generics: impl Iterator<Item = &'a GenericParam>,
136        type_name: Ident,
137    ) -> Self {
138        let mut type_maps_to_type = vec![];
139        for generic in generics {
140            match generic {
141                GenericParam::Lifetime(_) => {}
142                GenericParam::Type(ty) => type_maps_to_type.push((
143                    ty.ident.clone(),
144                    Ident::new(&format!("{}2", ty.ident), Span::call_site()),
145                )),
146                GenericParam::Const(_) => {}
147            }
148        }
149        AbstractFunctorFactory {
150            args,
151            type_maps_to_type,
152            type_name,
153        }
154    }
155
156    fn from_item_enum(args: Args, source: &mut ItemEnum) -> TokenStream2 {
157        let name = source.ident.clone();
158        let factory = AbstractFunctorFactory::from_generics(
159            args,
160            source.generics.params.iter(),
161            source.ident.clone(),
162        );
163        let map_name = factory.args.get_map_all_name();
164        let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
165        let mapped_params: Punctuated<TypeParam, Token![,]> = factory
166            .type_maps_to_type
167            .iter()
168            .map(|a| TypeParam::from(a.1.clone()))
169            .collect();
170        let fn_args = factory.make_fn_arguments();
171        let implemented: Punctuated<TokenStream2, Token![,]> = source
172            .variants
173            .iter_mut()
174            .map(|variant| factory.implement_body_for_variant(variant))
175            .collect();
176        quote! {
177            impl #impl_gen #name #type_gen #where_clause {
178                pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
179                    match self {
180                        #implemented
181                    }
182                }
183            }
184        }
185    }
186
187    fn from_item_struct(args: Args, source: &mut ItemStruct) -> TokenStream2 {
188        let name = source.ident.clone();
189        let factory = AbstractFunctorFactory::from_generics(
190            args,
191            source.generics.params.iter(),
192            source.ident.clone(),
193        );
194        let map_name = factory.args.get_map_all_name();
195        let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
196        let mapped_params: Punctuated<TypeParam, Token![,]> = factory
197            .type_maps_to_type
198            .iter()
199            .map(|a| TypeParam::from(a.1.clone()))
200            .collect();
201        let fn_args = factory.make_fn_arguments();
202        let (fields, names_for_unnamed) = Self::unpack_fields(&source.fields);
203        let expanded = match source.fields {
204            Fields::Named(_) => quote! {#name {#fields}},
205            Fields::Unnamed(_) => quote! {#name(#fields)},
206            Fields::Unit => quote! {#name},
207        };
208        let implemented =
209            factory.apply_mapping_to_fields(&mut source.fields, name.clone(), names_for_unnamed);
210        quote! {
211            impl #impl_gen #name #type_gen #where_clause {
212                pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
213                    let #expanded = self;
214                    #implemented
215                }
216            }
217        }
218    }
219
220    fn implement_body_for_variant(&self, variant: &mut Variant) -> TokenStream2 {
221        let type_name = &self.type_name;
222        let name = &variant.ident;
223        let (unpacked, name_mapping) = Self::unpack_fields(&variant.fields);
224        match variant.fields {
225            Fields::Named(_) => {
226                let implemented =
227                    self.apply_mapping_to_fields(&mut variant.fields, name.clone(), name_mapping);
228                quote! {
229                    #type_name::#name{#unpacked} => #type_name::#implemented
230                }
231            }
232            Fields::Unnamed(_) => {
233                let implemented =
234                    self.apply_mapping_to_fields(&mut variant.fields, name.clone(), name_mapping);
235                quote! {
236                    #type_name::#name(#unpacked) => #type_name::#implemented
237                }
238            }
239            Fields::Unit => quote! {
240                #type_name::#name => #type_name::#name
241            },
242        }
243    }
244
245    /// The behaviour for this is such that the order of generics for the container type is followed best as possible.
246    fn get_mappable_generics_of_type(&self, ty: &Type) -> Option<FieldMapping> {
247        if let Type::Path(path) = ty {
248            let last_segment = path.path.segments.last();
249            // unwraps here because segments' length is checked to be >0 right here.
250            if path.path.segments.len() == 1
251                && self
252                    .type_maps_to_type
253                    .iter()
254                    .any(|(gen, _)| *gen == last_segment.unwrap().ident)
255            {
256                // the type is a path with 1 segment whose identifier matches a type parameter, so it's a trivial case.
257                return Some(FieldMapping::Trivial(last_segment.unwrap().ident.clone()));
258            }
259        }
260        let mut buffer = Vec::new();
261        self.recursive_get_generics_of_type_to_buffer(ty, &mut buffer);
262        (!buffer.is_empty()).then_some(FieldMapping::SubMapForArgs(buffer))
263    }
264
265    // needs to take a vector as its way of knowing what types have been found to preserve order within the
266    // recursed types.
267    fn recursive_get_generics_of_type_to_buffer(&self, ty: &Type, buffer: &mut Vec<Ident>) {
268        match ty {
269            Type::Array(array) => {
270                self.recursive_get_generics_of_type_to_buffer(&array.elem, buffer)
271            }
272            Type::Paren(paren) => {
273                self.recursive_get_generics_of_type_to_buffer(&paren.elem, buffer)
274            }
275            Type::Path(path) => {
276                if let Some(segment) = path.path.segments.last().filter(|segment| {
277                    self.type_maps_to_type
278                        .iter()
279                        .any(|(gen, _)| segment.ident == *gen)
280                }) {
281                    if !buffer.contains(&segment.ident) {
282                        buffer.push(segment.ident.clone());
283                    }
284                    if let PathArguments::AngleBracketed(generics) = &segment.arguments {
285                        for generic in &generics.args {
286                            if let syn::GenericArgument::Type(ty) = generic {
287                                self.recursive_get_generics_of_type_to_buffer(ty, buffer)
288                            }
289                        }
290                    }
291                }
292                // this needs to be out of the last check otherwise we don't properly recurse on non-type-params.
293                if let Some(PathArguments::AngleBracketed(generics)) = &path.path.segments.last().map(|segment| &segment.arguments) {
294                    for generic in &generics.args {
295                        if let syn::GenericArgument::Type(ty) = generic {
296                            self.recursive_get_generics_of_type_to_buffer(ty, buffer)
297                        }
298                    }
299                }
300            }
301            Type::Tuple(tuple) => {
302                for ty in &tuple.elems {
303                    self.recursive_get_generics_of_type_to_buffer(ty, buffer)
304                }
305            }
306            _ => {}
307        }
308    }
309
310    fn unpack_fields(fields: &Fields) -> (TokenStream2, FieldNameMapping) {
311        match fields {
312            Fields::Named(named) => {
313                let names: Punctuated<Ident, Token![,]> = named
314                    .named
315                    .iter()
316                    .map(|field| field.ident.clone().unwrap())
317                    .collect();
318                let tokens = quote! {
319                    #names
320                };
321                (tokens, None)
322            }
323            Fields::Unnamed(unnamed) => {
324                let faux_names: Punctuated<Ident, Token![,]> = unnamed
325                    .unnamed
326                    .iter()
327                    .enumerate()
328                    .map(|(num, _)| Ident::new(&format!("field_{num}"), Span::call_site()))
329                    .collect();
330                let tokens = quote! {
331                    #faux_names
332                };
333                (tokens, Some(faux_names.into_iter().collect()))
334            }
335            Fields::Unit => (quote! {}, None),
336        }
337    }
338
339    fn apply_mapping_to_fields(
340        &self,
341        fields: &mut Fields,
342        name: Ident,
343        names_for_unnamed: FieldNameMapping,
344    ) -> TokenStream2 {
345        match fields {
346            Fields::Named(named) => {
347                let mapped: Punctuated<TokenStream2, Token![,]> = named
348                    .named
349                    .iter_mut()
350                    .map(|field| {
351                        // we can unwrap as it's a named field.
352                        let field_name = field.ident.clone().unwrap();
353                        let new_field_content =
354                            self.apply_mapping_to_field_ref(field, quote! {#field_name});
355                        quote! {
356                            #field_name: #new_field_content
357                        }
358                    })
359                    .collect();
360                let implemented = mapped.to_token_stream();
361                quote! {
362                    #name {
363                        #implemented
364                    }
365                }
366            }
367            Fields::Unnamed(unnamed) => {
368                let names = names_for_unnamed.unwrap();
369                let mapped: Punctuated<TokenStream2, Token![,]> = unnamed
370                    .unnamed
371                    .iter_mut()
372                    .enumerate()
373                    .map(|(field_num, field)| {
374                        let name_of_field = &names[field_num];
375                        let field_ref = quote! {#name_of_field};
376                        let new_field_content = self.apply_mapping_to_field_ref(field, field_ref);
377                        quote! {
378                            #new_field_content
379                        }
380                    })
381                    .collect();
382                quote! {
383                    #name(#mapped)
384                }
385            }
386            Fields::Unit => quote! {#name},
387        }
388    }
389
390    fn apply_mapping_to_field_ref(
391        &self,
392        field: &mut Field,
393        field_ref: TokenStream2,
394    ) -> TokenStream2 {
395        match self.get_mappable_generics_of_type(&field.ty) {
396            Some(mapping) => match mapping {
397                FieldMapping::Trivial(type_to_map) => {
398                    let map = self.args.get_whole_map_name(&type_to_map);
399                    quote! {
400                        #map(#field_ref)
401                    }
402                }
403                // attempt recursion on the type.
404                FieldMapping::SubMapForArgs(sub_maps) => {
405                    let map_all_name = self.args.get_map_all_name();
406                    let all_fns: Punctuated<TokenStream2, Token![,]> = sub_maps
407                        .iter()
408                        .map(|ident| {
409                            let ident = self.args.get_whole_map_name(ident);
410                            quote! {&#ident}
411                        })
412                        .collect();
413                    match FieldArg::find_in_attributes(field.attrs.iter()) {
414                        Some(FieldArg { alt_function }) => {
415                            FieldArg::remove_from_attributes(&mut field.attrs);
416                            quote! {
417                                (#alt_function)(#field_ref, #all_fns)
418                            }
419                        }
420                        None => {
421                            quote! {
422                                #field_ref.#map_all_name(#all_fns)
423                            }
424                        }
425                    }
426                }
427            },
428            // There's no need to map, so we just move.
429            None => quote! {#field_ref},
430        }
431    }
432
433    fn make_fn_arguments(&self) -> TokenStream2 {
434        let mapped: Punctuated<TokenStream2, Token![,]> = self
435            .type_maps_to_type
436            .iter()
437            .map(|(from, to)| {
438                let fn_name = self.args.get_whole_map_name(from);
439                // it's this or TypedPat / PatTyped
440                // don't need to trailing comma this cos the punctuated type does that for us.
441                quote! {
442                    #fn_name: impl Fn(#from) -> #to
443                }
444            })
445            .collect();
446        mapped.into_token_stream()
447    }
448}
449
450struct FieldArg {
451    pub alt_function: TokenStream2,
452}
453
454impl FieldArg {
455    fn map_with_attr_ident() -> Ident {
456        Ident::new("map_with", Span::call_site())
457    }
458
459    fn remove_from_attributes(attributes: &mut Vec<Attribute>) {
460        let ident_to_check = Self::map_with_attr_ident();
461        // reverse the iterator so that we can remove indices easily.
462        let to_remove: Vec<_> = attributes
463            .iter()
464            .enumerate()
465            .rev()
466            .filter_map(|(num, attribute)| match &attribute.meta {
467                Meta::Path(_) => None,
468                Meta::List(meta) => {
469                    let last = meta.path.segments.last()?;
470                    (last.ident == ident_to_check).then_some(num)
471                }
472                Meta::NameValue(_) => None,
473            })
474            .collect();
475        for remove in to_remove {
476            attributes.remove(remove);
477        }
478    }
479
480    fn find_in_attributes<'a>(mut attributes: impl Iterator<Item = &'a Attribute>) -> Option<Self> {
481        attributes.find_map(|attribute| match &attribute.meta {
482            Meta::Path(_) => None,
483            Meta::List(meta) => Self::from_meta_list(meta),
484            Meta::NameValue(_) => None,
485        })
486    }
487
488    fn from_meta_list(meta: &MetaList) -> Option<Self> {
489        let ident_to_check = Self::map_with_attr_ident();
490        if meta.path.segments.iter().last().map(|x| &x.ident) == Some(&ident_to_check) {
491            Some(Self {
492                alt_function: meta.tokens.clone(),
493            })
494        } else {
495            None
496        }
497    }
498}
499// #[proc_macro_attribute]
500// pub fn derive_n_foldable(attr: TokenStream, item: TokenStream) -> TokenStream {
501//     unimplemented!()
502// }