functor_derive_lib/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use crate::generate_fmap_body::generate_fmap_body;
4use crate::map::{map_path, map_where};
5use crate::parse_attribute::parse_attribute;
6use proc_macro2::{Ident, TokenStream};
7use proc_macro_error::proc_macro_error;
8use quote::{format_ident, quote};
9use syn::token::Colon;
10use syn::{
11    parse_macro_input, Data, DeriveInput, Expr, ExprPath, GenericArgument, GenericParam, Path,
12    PathSegment, PredicateType, TraitBound, TraitBoundModifier, Type, TypeParamBound, TypePath,
13    WhereClause, WherePredicate,
14};
15
16mod generate_fmap_body;
17mod generate_map;
18mod map;
19mod parse_attribute;
20
21#[proc_macro_derive(Functor, attributes(functor))]
22#[proc_macro_error]
23pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    let input = parse_macro_input!(input as DeriveInput);
25
26    // Name of the Struct or Enum we are implementing the `Functor` trait for.
27    let def_name = input.ident.clone();
28
29    // Get the attributes for this invocation. If no attributes are given, the first generic is used as default.
30    let attribute = parse_attribute(&input);
31
32    // Get the generic parameters leaving only the bounds and attributes.
33    let source_params = input
34        .generics
35        .params
36        .iter()
37        .map(|param| match param {
38            GenericParam::Type(param) => {
39                let mut param = param.clone();
40                param.eq_token = None;
41                param.default = None;
42                GenericParam::Type(param)
43            }
44            GenericParam::Const(param) => {
45                let mut param = param.clone();
46                param.eq_token = None;
47                param.default = None;
48                GenericParam::Const(param)
49            }
50            param => param.clone(),
51        })
52        .collect::<Vec<_>>();
53
54    // Maps the generic parameters to generic arguments for the source.
55    let source_args = source_params
56        .iter()
57        .map(|param| match param {
58            GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
59            GenericParam::Type(t) => GenericArgument::Type(Type::Path(TypePath {
60                qself: None,
61                path: Path::from(PathSegment::from(t.ident.clone())),
62            })),
63            GenericParam::Const(c) => GenericArgument::Const(Expr::Path(ExprPath {
64                attrs: vec![],
65                qself: None,
66                path: Path::from(PathSegment::from(c.ident.clone())),
67            })),
68        })
69        .collect::<Vec<_>>();
70
71    // Lints
72    let lints = quote! {
73        #[allow(absolute_paths_not_starting_with_crate)]
74        #[allow(bare_trait_objects)]
75        #[allow(deprecated)]
76        #[allow(drop_bounds)]
77        #[allow(dyn_drop)]
78        #[allow(non_camel_case_types)]
79        #[allow(trivial_bounds)]
80        #[allow(unused_qualifications)]
81        #[allow(clippy::allow)]
82        #[automatically_derived]
83    };
84
85    let mut tokens = TokenStream::new();
86
87    // Include default Functor implementation.
88    if let Some(default) = attribute.default {
89        tokens.extend(generate_default_impl(
90            &default,
91            &def_name,
92            &source_params,
93            &source_args,
94            &input.generics.where_clause,
95            &lints,
96        ));
97    }
98
99    // Include all named implementations.
100    for (param, name) in attribute.name_map {
101        tokens.extend(generate_named_impl(
102            &param,
103            &name,
104            &def_name,
105            &source_params,
106            &source_args,
107            &input.generics.where_clause,
108            &lints,
109        ));
110    }
111
112    // Include internal implementations.
113    tokens.extend(generate_refs_impl(
114        &input.data,
115        &def_name,
116        &source_params,
117        &source_args,
118        &input.generics.where_clause,
119        &lints,
120    ));
121
122    tokens.into()
123}
124
125fn find_index(source_params: &[GenericParam], ident: &Ident) -> usize {
126    for (total, param) in source_params.iter().enumerate() {
127        match param {
128            GenericParam::Type(t) if &t.ident == ident => return total,
129            _ => {}
130        }
131    }
132    unreachable!()
133}
134
135fn generate_refs_impl(
136    data: &Data,
137    def_name: &Ident,
138    source_params: &Vec<GenericParam>,
139    source_args: &Vec<GenericArgument>,
140    where_clause: &Option<WhereClause>,
141    lints: &TokenStream,
142) -> TokenStream {
143    let mut tokens = TokenStream::new();
144    for param in source_params {
145        if let GenericParam::Type(t) = param {
146            let param_ident = t.ident.clone();
147            let param_idx = find_index(source_params, &t.ident);
148
149            let functor_trait_ident = format_ident!("Functor{param_idx}");
150            let fmap_ident = format_ident!("__fmap_{param_idx}_ref");
151            let try_fmap_ident = format_ident!("__try_fmap_{param_idx}_ref");
152
153            // Generate body of the `fmap` implementation.
154            let Some(fmap_ref_body) = generate_fmap_body(data, def_name, &param_ident, false)
155            else {
156                continue;
157            };
158            let Some(try_fmap_ref_body) = generate_fmap_body(data, def_name, &param_ident, true)
159            else {
160                continue;
161            };
162
163            let mut target_args = source_args.clone();
164            target_args[param_idx] = GenericArgument::Type(Type::Path(TypePath {
165                qself: None,
166                path: Path::from(PathSegment::from(format_ident!("__B"))),
167            }));
168
169            if let Some(fn_where_clause) =
170                create_fn_where_clause(where_clause, source_params, &param_ident)
171            {
172                tokens.extend(quote!(
173                    #lints
174                    impl<#(#source_params),*> #def_name<#(#source_args),*> #where_clause {
175                        pub fn #fmap_ident<__B>(self, __f: &impl Fn(#param_ident) -> __B) -> #def_name<#(#target_args),*> #fn_where_clause {
176                            use ::functor_derive::*;
177                            #fmap_ref_body
178                        }
179
180                        pub fn #try_fmap_ident<__B, __E>(self, __f: &impl Fn(#param_ident) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> #fn_where_clause {
181                            use ::functor_derive::*;
182                            Ok(#try_fmap_ref_body)
183                        }
184                    }
185                ))
186            } else {
187                tokens.extend(quote!(
188                    #lints
189                    impl<#(#source_params),*> ::functor_derive::#functor_trait_ident<#param_ident> for #def_name<#(#source_args),*> #where_clause {
190                        type Target<__B> = #def_name<#(#target_args),*>;
191
192                        fn #fmap_ident<__B>(self, __f: &impl Fn(#param_ident) -> __B) -> #def_name<#(#target_args),*> {
193                            use ::functor_derive::*;
194                            #fmap_ref_body
195                        }
196
197                        fn #try_fmap_ident<__B, __E>(self, __f: &impl Fn(#param_ident) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> {
198                            use ::functor_derive::*;
199                            Ok(#try_fmap_ref_body)
200                        }
201                    }
202                ))
203            }
204        }
205    }
206    tokens
207}
208
209fn generate_default_impl(
210    param: &Ident,
211    def_name: &Ident,
212    source_params: &Vec<GenericParam>,
213    source_args: &Vec<GenericArgument>,
214    where_clause: &Option<WhereClause>,
215    lints: &TokenStream,
216) -> TokenStream {
217    let default_idx = find_index(source_params, param);
218
219    // Create generic arguments for the target. We use `__B` for the mapped generic.
220    let mut target_args = source_args.clone();
221    target_args[default_idx] = GenericArgument::Type(Type::Path(TypePath {
222        qself: None,
223        path: Path::from(PathSegment::from(format_ident!("__B"))),
224    }));
225
226    let default_map = format_ident!("__fmap_{default_idx}_ref");
227    let default_try_map = format_ident!("__try_fmap_{default_idx}_ref");
228
229    if let Some(fn_where_clause) = create_fn_where_clause(where_clause, source_params, param) {
230        quote!(
231            #lints
232            impl<#(#source_params),*> #def_name<#(#source_args),*> #where_clause {
233                pub fn fmap<__B>(self, __f: impl Fn(#param) -> __B) -> #def_name<#(#target_args),*> #fn_where_clause {
234                    use ::functor_derive::*;
235                    self.#default_map(&__f)
236                }
237
238                pub fn try_fmap<__B, __E>(self, __f: impl Fn(#param) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> #fn_where_clause {
239                    use ::functor_derive::*;
240                    self.#default_try_map(&__f)
241                }
242            }
243        )
244    } else {
245        quote!(
246            #lints
247            impl<#(#source_params),*> ::functor_derive::Functor<#param> for #def_name<#(#source_args),*> {
248                type Target<__B> = #def_name<#(#target_args),*>;
249
250                fn fmap<__B>(self, __f: impl Fn(#param) -> __B) -> #def_name<#(#target_args),*> {
251                    use ::functor_derive::*;
252                    self.#default_map(&__f)
253                }
254
255                fn try_fmap<__B, __E>(self, __f: impl Fn(#param) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> {
256                    use ::functor_derive::*;
257                    self.#default_try_map(&__f)
258                }
259            }
260        )
261    }
262}
263
264fn generate_named_impl(
265    param: &Ident,
266    name: &Ident,
267    def_name: &Ident,
268    source_params: &Vec<GenericParam>,
269    source_args: &Vec<GenericArgument>,
270    where_clause: &Option<WhereClause>,
271    lints: &TokenStream,
272) -> TokenStream {
273    let default_idx = find_index(source_params, param);
274
275    // Create generic arguments for the target. We use `__B` for the mapped generic.
276    let mut target_args = source_args.clone();
277    target_args[default_idx] = GenericArgument::Type(Type::Path(TypePath {
278        qself: None,
279        path: Path::from(PathSegment::from(format_ident!("__B"))),
280    }));
281
282    let fmap_name = format_ident!("fmap_{name}");
283    let try_fmap_name = format_ident!("try_fmap_{name}");
284
285    let fmap = format_ident!("__fmap_{default_idx}_ref");
286    let fmap_try = format_ident!("__try_fmap_{default_idx}_ref");
287
288    let fn_where_clause = create_fn_where_clause(where_clause, source_params, param);
289
290    quote!(
291        #lints
292        impl<#(#source_params),*> #def_name<#(#source_args),*> #where_clause {
293            pub fn #fmap_name<__B>(self, __f: impl Fn(#param) -> __B) -> #def_name<#(#target_args),*> #fn_where_clause {
294                use ::functor_derive::*;
295                self.#fmap(&__f)
296            }
297
298            pub fn #try_fmap_name<__B, __E>(self, __f: impl Fn(#param) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> #fn_where_clause {
299                use ::functor_derive::*;
300                self.#fmap_try(&__f)
301            }
302        }
303    )
304}
305
306fn create_fn_where_clause(
307    where_clause: &Option<WhereClause>,
308    source_params: &Vec<GenericParam>,
309    param: &Ident,
310) -> Option<WhereClause> {
311    let mut predicates = where_clause
312        .iter()
313        .flat_map(|where_clause| map_where(where_clause, param))
314        .flat_map(|where_clause| where_clause.predicates)
315        .collect::<Vec<_>>();
316
317    for source_param in source_params {
318        if let GenericParam::Type(typ) = source_param {
319            if typ.bounds.is_empty() {
320                continue;
321            };
322
323            let bounds = typ
324                .bounds
325                .iter()
326                .cloned()
327                .flat_map(|bound| {
328                    if let TypeParamBound::Trait(mut trt) = bound {
329                        match trt.modifier {
330                            TraitBoundModifier::Maybe(_) => None,
331                            TraitBoundModifier::None => {
332                                map_path(&mut trt.path, param, &mut false);
333                                Some(TypeParamBound::Trait(trt))
334                            }
335                        }
336                    } else {
337                        Some(bound)
338                    }
339                })
340                .collect();
341
342            predicates.push(WherePredicate::Type(PredicateType {
343                lifetimes: None,
344                bounded_ty: Type::Path(TypePath {
345                    qself: None,
346                    path: Path {
347                        leading_colon: None,
348                        segments: [PathSegment {
349                            ident: if &typ.ident == param {
350                                format_ident!("__B")
351                            } else {
352                                typ.ident.clone()
353                            },
354                            arguments: Default::default(),
355                        }]
356                        .into_iter()
357                        .collect(),
358                    },
359                }),
360                colon_token: Colon::default(),
361                bounds,
362            }))
363        }
364    }
365
366    // Add param: Sized
367    predicates.push(WherePredicate::Type(PredicateType {
368        lifetimes: None,
369        bounded_ty: Type::Path(TypePath {
370            qself: None,
371            path: Path {
372                leading_colon: None,
373                segments: [PathSegment {
374                    ident: param.clone(),
375                    arguments: Default::default(),
376                }]
377                .into_iter()
378                .collect(),
379            },
380        }),
381        colon_token: Colon::default(),
382        bounds: [TypeParamBound::Trait(TraitBound {
383            paren_token: None,
384            modifier: TraitBoundModifier::None,
385            lifetimes: None,
386            path: Path {
387                leading_colon: None,
388                segments: [PathSegment {
389                    ident: format_ident!("Sized"),
390                    arguments: Default::default(),
391                }]
392                .into_iter()
393                .collect(),
394            },
395        })]
396        .into_iter()
397        .collect(),
398    }));
399
400    if predicates.is_empty() {
401        None
402    } else {
403        Some(WhereClause {
404            where_token: Default::default(),
405            predicates: predicates.into_iter().collect(),
406        })
407    }
408}