higher_derive/
lib.rs

1#![recursion_limit = "256"]
2
3//! Custom derives for the [`higher`][higher] crate.
4//!
5//! Please see the relevant crate for documentation.
6//!
7//! [higher]: https://docs.rs/crate/higher
8
9extern crate proc_macro;
10
11use std::collections::HashMap;
12
13use proc_macro2::{Span, TokenStream};
14use quote::{quote, quote_spanned};
15use syn::{
16    parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DataEnum,
17    DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Ident, Index, Type,
18    TypeParam,
19};
20
21fn type_params_replace(
22    input_params: &Punctuated<GenericParam, Comma>,
23    replace: &TypeParam,
24    with: Ident,
25) -> Punctuated<GenericParam, Comma> {
26    let mut output = input_params.clone();
27    for param in output.iter_mut() {
28        match param {
29            GenericParam::Type(ref mut type_param) if type_param == replace => {
30                *(&mut type_param.ident) = with;
31                break;
32            }
33            _ => {}
34        }
35    }
36    output
37}
38
39fn report_error(span: Span, msg: &str) -> proc_macro::TokenStream {
40    (quote_spanned! {span => compile_error! {#msg}}).into()
41}
42
43fn decide_functor_generic_type<'a>(
44    input: &'a DeriveInput,
45) -> Result<&'a TypeParam, proc_macro::TokenStream> {
46    let mut generics_iter = input.generics.type_params();
47    let generic_type = match generics_iter.next() {
48        Some(t) => t,
49        None => {
50            return Err(report_error(
51                input.ident.span(),
52                "can't derive Functor for a type without type parameters",
53            ));
54        }
55    };
56
57    if let Some(next_type_param) = generics_iter.next() {
58        return Err(report_error(
59            next_type_param.span(),
60            "can't derive Functor for a type with multiple type parameters; did you mean Bifunctor?",
61        ));
62    }
63
64    return Ok(generic_type);
65}
66
67fn decide_bifunctor_generic_types<'a>(
68    input: &'a DeriveInput,
69) -> Result<(&'a TypeParam, &'a TypeParam), proc_macro::TokenStream> {
70    let mut generics_iter = input.generics.type_params();
71    let generic_type_a = match generics_iter.next() {
72        Some(t) => t,
73        None => {
74            return Err(report_error(
75                input.ident.span(),
76                "can't derive Bifunctor for a type without type parameters",
77            ))
78        }
79    };
80
81    let generic_type_b = match generics_iter.next() {
82        Some(t) => t,
83        None => return Err(report_error(
84            input.ident.span(),
85            "can't derive Bifunctor for a type with only one type parameter; did you mean Functor?",
86        )),
87    };
88
89    if let Some(next_type_param) = generics_iter.next() {
90        return Err(report_error(
91            next_type_param.span(),
92            "can't derive Functor for a type with three or more type parameters",
93        ));
94    }
95
96    return Ok((generic_type_a, generic_type_b));
97}
98
99#[proc_macro_derive(Bifunctor)]
100pub fn derive_bifunctor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
101    let input = parse_macro_input!(input as DeriveInput);
102    let name = &input.ident;
103    let type_params = &input.generics.params;
104    let where_clause = &input.generics.where_clause;
105
106    let (generic_type_a, generic_type_b) = match decide_bifunctor_generic_types(&input) {
107        Ok(t) => t,
108        Err(err) => return err,
109    };
110
111    let type_map = HashMap::from([
112        (
113            generic_type_a.ident.clone(),
114            Ident::new("left", Span::call_site()),
115        ),
116        (
117            generic_type_b.ident.clone(),
118            Ident::new("right", Span::call_site()),
119        ),
120    ]);
121
122    let bimap_impl = match &input.data {
123        Data::Struct(data) => match &data.fields {
124            Fields::Named(fields) => derive_functor_named_struct(name, fields, &type_map),
125            Fields::Unnamed(fields) => derive_functor_unnamed_struct(name, fields, &type_map),
126            Fields::Unit => {
127                return report_error(
128                    input.ident.span(),
129                    "can't derive Bifunctor for an empty struct",
130                );
131            }
132        },
133        Data::Enum(data) => derive_functor_enum(name, data, &type_map),
134        Data::Union(_) => {
135            return report_error(
136                input.ident.span(),
137                "can't derive Bifunctor for a union type",
138            );
139        }
140    };
141
142    let type_params_generic = type_params_replace(
143        &type_params_replace(
144            type_params,
145            generic_type_a,
146            Ident::new("DerivedTargetTypeA", Span::call_site()),
147        ),
148        generic_type_b,
149        Ident::new("DerivedTargetTypeB", Span::call_site()),
150    );
151
152    quote!(
153        impl<#type_params> ::higher::Bifunctor<'_, #generic_type_a, #generic_type_b> for #name<#type_params> #where_clause {
154            type Target<DerivedTargetTypeA, DerivedTargetTypeB> = #name<#type_params_generic>;
155            fn bimap<DerivedTypeA, DerivedTypeB, L, R>(self, left: L, right: R) -> Self::Target<DerivedTypeA, DerivedTypeB>
156            where
157                L: Fn(#generic_type_a) -> DerivedTypeA,
158                R: Fn(#generic_type_b) -> DerivedTypeB
159            {
160                #bimap_impl
161            }
162        }
163    )
164    .into()
165}
166
167#[proc_macro_derive(Functor)]
168pub fn derive_functor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
169    let input = parse_macro_input!(input as DeriveInput);
170    let name = &input.ident;
171    let type_params = &input.generics.params;
172    let where_clause = &input.generics.where_clause;
173
174    let generic_type = match decide_functor_generic_type(&input) {
175        Ok(t) => t,
176        Err(err) => return err,
177    };
178
179    let type_map = HashMap::from([(
180        generic_type.ident.clone(),
181        Ident::new("f", Span::call_site()),
182    )]);
183
184    let fmap_impl = match &input.data {
185        Data::Struct(data) => match &data.fields {
186            Fields::Named(fields) => derive_functor_named_struct(name, fields, &type_map),
187            Fields::Unnamed(fields) => derive_functor_unnamed_struct(name, fields, &type_map),
188            Fields::Unit => {
189                return report_error(
190                    input.ident.span(),
191                    "can't derive Functor for an empty struct",
192                );
193            }
194        },
195        Data::Enum(data) => derive_functor_enum(name, data, &type_map),
196        Data::Union(_) => {
197            return report_error(input.ident.span(), "can't derive Functor for a union type");
198        }
199    };
200
201    let type_params_with_t = type_params_replace(
202        type_params,
203        generic_type,
204        Ident::new("DerivedTargetType", Span::call_site()),
205    );
206
207    quote!(
208        impl<#type_params> ::higher::Functor<'_, #generic_type> for #name<#type_params> #where_clause {
209            type Target<DerivedTargetType> = #name<#type_params_with_t>;
210            fn fmap<DerivedType, F>(self, f: F) -> Self::Target<DerivedType>
211            where
212                F: Fn(#generic_type) -> DerivedType
213            {
214                #fmap_impl
215            }
216        }
217    )
218    .into()
219}
220
221fn match_type_param<'a>(params: &'a HashMap<Ident, Ident>, ty: &Type) -> Option<&'a Ident> {
222    if let Type::Path(path) = ty {
223        if let Some(segment) = path.path.segments.iter().next() {
224            return params.get(&segment.ident);
225        }
226    }
227    None
228}
229
230fn filter_fields<'a, P, F1, F2>(
231    fields: &'a Punctuated<Field, P>,
232    ty: &HashMap<Ident, Ident>,
233    transform: F1,
234    copy: F2,
235) -> Vec<TokenStream>
236where
237    F1: Fn(&Ident, &Ident) -> TokenStream,
238    F2: Fn(&Ident) -> TokenStream,
239{
240    fields
241        .iter()
242        .map(|field| {
243            if let Some(f) = match_type_param(ty, &field.ty) {
244                transform(&field.ident.clone().unwrap(), f)
245            } else {
246                copy(&field.ident.clone().unwrap())
247            }
248        })
249        .collect()
250}
251
252fn derive_functor_named_struct(
253    name: &Ident,
254    fields: &FieldsNamed,
255    generic_types: &HashMap<Ident, Ident>,
256) -> TokenStream {
257    let apply_fields = filter_fields(
258        &fields.named,
259        generic_types,
260        |field, function_name| {
261            quote! {
262                #field: #function_name(self.#field),
263            }
264        },
265        |field| {
266            quote! {
267                #field: self.#field,
268            }
269        },
270    )
271    .into_iter();
272    quote! {
273        #name {
274            #(#apply_fields)*
275        }
276    }
277}
278
279fn derive_functor_unnamed_struct(
280    name: &Ident,
281    fields: &FieldsUnnamed,
282    generic_types: &HashMap<Ident, Ident>,
283) -> TokenStream {
284    let fields = fields.unnamed.iter().enumerate().map(|(index, field)| {
285        let index = Index::from(index);
286        if let Some(function_name) = match_type_param(generic_types, &field.ty) {
287            quote! { #function_name(self.#index), }
288        } else {
289            quote! { self.#index, }
290        }
291    });
292    quote! { #name(#(#fields)*) }
293}
294
295fn derive_functor_enum(
296    name: &Ident,
297    data: &DataEnum,
298    generic_types: &HashMap<Ident, Ident>,
299) -> TokenStream {
300    let variants = data.variants.iter().map(|variant| {
301        let ident = &variant.ident;
302        match &variant.fields {
303            Fields::Named(fields) => {
304                let args: Vec<Ident> = fields
305                    .named
306                    .iter()
307                    .map(|field| {
308                        Ident::new(
309                            &format!("arg_{}", field.ident.clone().unwrap()),
310                            field.ident.clone().unwrap().span(),
311                        )
312                    })
313                    .collect();
314                let apply =
315                    fields
316                        .named
317                        .iter()
318                        .zip(args.clone().into_iter())
319                        .map(|(field, arg)| {
320                            let name = &field.ident;
321                            if let Some(function_name) = match_type_param(generic_types, &field.ty)
322                            {
323                                quote! { #name: #function_name(#arg) }
324                            } else {
325                                quote! { #name: #arg }
326                            }
327                        });
328                let args = fields
329                    .named
330                    .iter()
331                    .zip(args.into_iter())
332                    .map(|(field, arg)| {
333                        let name = &field.ident;
334                        quote! { #name:#arg }
335                    });
336                quote! {
337                    #name::#ident { #(#args,)* } => #name::#ident { #(#apply,)* },
338                }
339            }
340            Fields::Unnamed(fields) => {
341                let args: Vec<Ident> = fields
342                    .unnamed
343                    .iter()
344                    .enumerate()
345                    .map(|(index, _)| Ident::new(&format!("arg{}", index), Span::call_site()))
346                    .collect();
347                let fields = fields.unnamed.iter().zip(args.iter()).map(|(field, arg)| {
348                    if let Some(function_name) = match_type_param(generic_types, &field.ty) {
349                        quote! { #function_name(#arg) }
350                    } else {
351                        quote! { #arg }
352                    }
353                });
354                let args = args.iter();
355                quote! {
356                    #name::#ident(#(#args,)*) => #name::#ident(#(#fields,)*),
357                }
358            }
359            Fields::Unit => quote! {
360                #name::#ident => #name::#ident,
361            },
362        }
363    });
364    quote! {
365        match self {
366            #(#variants)*
367        }
368    }
369}
370
#[cfg(test)]
mod test {
    use higher::{Bifunctor, Functor};

    // Fixtures for `#[derive(Functor)]`: a braced struct, a tuple struct, and an
    // enum that mixes mapped, unmapped and unit variants.

    #[derive(PartialEq, Eq, Debug, Functor)]
    struct FunctorNamed<A> {
        named: A,
    }

    #[derive(PartialEq, Eq, Debug, Functor)]
    struct FunctorUnnamed<A>(A);

    #[derive(PartialEq, Eq, Debug, Functor)]
    #[allow(dead_code)]
    enum FunctorEnum<A> {
        Some(A),
        SomeNumber(usize),
        SomeOther(A),
        None,
    }

    /// `fmap` applies the closure to every field of type `A` and leaves all
    /// other fields and variants untouched.
    #[test]
    fn derive_functor() {
        let add_three = |x: u32| x + 3;

        let named = FunctorNamed { named: 2u32 }.fmap(add_three);
        assert_eq!(named, FunctorNamed { named: 5u32 });

        let unnamed = FunctorUnnamed(2u32).fmap(add_three);
        assert_eq!(unnamed, FunctorUnnamed(5u32));

        let mapped = FunctorEnum::Some(2u32).fmap(add_three);
        assert_eq!(mapped, FunctorEnum::Some(5u32));

        let untouched = FunctorEnum::<u32>::SomeNumber(2).fmap(add_three);
        assert_eq!(untouched, FunctorEnum::<u32>::SomeNumber(2));

        let other = FunctorEnum::SomeOther(2u32).fmap(add_three);
        assert_eq!(other, FunctorEnum::SomeOther(5u32));

        let unit = FunctorEnum::<u32>::None.fmap(add_three);
        assert_eq!(unit, FunctorEnum::None);
    }

    // Fixtures for `#[derive(Bifunctor)]`, mirroring the Functor ones with two
    // mapped parameters.

    #[derive(PartialEq, Eq, Debug, Bifunctor)]
    struct BifunctorNamed<A, B> {
        a: A,
        b: B,
    }

    #[derive(PartialEq, Eq, Debug, Bifunctor)]
    struct BifunctorUnnamed<A, B>(A, B);

    #[derive(PartialEq, Eq, Debug, Bifunctor)]
    #[allow(dead_code)]
    enum BifunctorEnum<A, B> {
        Ok(A),
        Err(B),
        Number(usize),
        Nothing,
    }

    /// `bimap` applies the left closure to `A` fields and the right closure to
    /// `B` fields, and leaves everything else untouched.
    #[test]
    fn derive_bifunctor() {
        let add_three = |x: u32| x + 3;
        let add_four = |x: u8| x + 4;

        let named = BifunctorNamed { a: 2u32, b: 2u8 }.bimap(add_three, add_four);
        assert_eq!(named, BifunctorNamed { a: 5u32, b: 6u8 });

        let unnamed = BifunctorUnnamed(2u32, 2u8).bimap(add_three, add_four);
        assert_eq!(unnamed, BifunctorUnnamed(5u32, 6u8));

        let left = BifunctorEnum::<u32, u8>::Ok(2u32).bimap(add_three, add_four);
        assert_eq!(left, BifunctorEnum::Ok(5u32));

        let right = BifunctorEnum::<u32, u8>::Err(2u8).bimap(add_three, add_four);
        assert_eq!(right, BifunctorEnum::Err(6u8));

        let untouched = BifunctorEnum::<u32, u8>::Number(2).bimap(add_three, add_four);
        assert_eq!(untouched, BifunctorEnum::Number(2));

        let unit = BifunctorEnum::<u32, u8>::Nothing.bimap(add_three, add_four);
        assert_eq!(unit, BifunctorEnum::Nothing);
    }
}