metaemu_state_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{parse_macro_input, DeriveInput, Ident, Token};
7use syn::parse::{Parse, ParseStream};
8use syn::spanned::Spanned;
9
10use itertools::Itertools;
11
/// How a `#[fugue]`-marked field is turned into the references returned by the
/// generated `AsState*` accessors.
#[derive(Clone)]
enum Transform {
    /// Bare `#[fugue]`: the field is borrowed directly (`&self.f` / `&mut self.f`).
    Nothing,
    /// `#[fugue(|x| ...)]`: the same closure is applied to both `&self.f` and
    /// `&mut self.f`.
    Single(syn::ExprClosure),
    /// `#[fugue(|x| ..., |x| ...)]`: the first closure is applied to `&self.f`
    /// (shared accessor), the second to `&mut self.f` (mutable accessor).
    Pair(syn::ExprClosure, syn::ExprClosure),
}
18
19fn parse_transform(t: ParseStream) -> syn::Result<Transform> {
20    let first = syn::ExprClosure::parse(t)?;
21    let peek = t.lookahead1();
22    if peek.peek(Token![,]) {
23        let _ = t.parse::<Token![,]>()?;
24        Ok(Transform::Pair(first, syn::ExprClosure::parse(t)?))
25    } else {
26        Ok(Transform::Single(first))
27    }
28}
29
30#[proc_macro_derive(AsState, attributes(fugue))]
31pub fn derive_as_state(input: TokenStream) -> TokenStream {
32    let input = parse_macro_input!(input as DeriveInput);
33    let span = input.span();
34    let ident = &input.ident;
35    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
36    let marked_fields = if let syn::Data::Struct(struc) = input.data {
37        let fields = struc.fields;
38        let marked = fields.into_iter()
39            .enumerate()
40            .filter_map(|(i, field)| {
41                field.attrs.iter()
42                    .find_map(|attr| if attr.path.is_ident("fugue") {
43                        if attr.tokens.is_empty() {
44                            Some(Ok((i, Transform::Nothing, field.clone())))
45                        } else {
46                            let attr_map = attr.parse_args_with(parse_transform);
47                            Some(attr_map.map(|map| (i, map, field.clone())))
48                        }
49                    } else {
50                        None
51                    })
52            })
53            .collect::<syn::Result<Vec<_>>>();
54
55        match marked {
56            Ok(marked) => marked,
57            Err(e) => return e.into_compile_error().into(),
58        }
59    } else {
60        return syn::Error::new(span, "only structs are supported")
61            .into_compile_error()
62            .into()
63    };
64
65    let marked_powerset = marked_fields.into_iter()
66        .powerset()
67        .filter(|v| !v.is_empty());
68
69    marked_powerset.into_iter().flat_map(|v| { let n = v.len(); v.into_iter().permutations(n).collect::<Vec<_>>() }).map(|fields| {
70        let types = fields.iter().map(|(_, _, f)| f.ty.clone()).collect::<Vec<_>>();
71        let (type_tokens, refs, muts) = if types.len() == 1 {
72            let ty = &types[0];
73            (
74                quote! { #ty },
75                quote! { &#ty },
76                quote! { &mut #ty },
77            )
78        } else {
79            (
80                quote! { #(#types),* },
81                quote! { (#(&#types),*) },
82                quote! { (#(&mut #types),*) },
83            )
84        };
85
86        let (accessors, accessors_mut) = fields.iter()
87            .map(|(i, ff, f)| if let Some(ref ident) = f.ident {
88                match ff {
89                    Transform::Nothing => (quote! { &self.#ident }, quote! { &mut self.#ident }),
90                    Transform::Single(ff) => (quote! { (#ff)(&self.#ident) }, quote! { (#ff)(&mut self.#ident) }),
91                    Transform::Pair(ffr, ffm) => (quote! { (#ffr)(&self.#ident) }, quote! { (#ffm)(&mut self.#ident) }),
92                }
93            } else {
94                match ff {
95                    Transform::Nothing => (quote! { &self.#i }, quote! { &mut self.#i }),
96                    Transform::Single(ff) => (quote! { (#ff)(&self.#i) }, quote! { (#ff)(&mut self.#i) }),
97                    Transform::Pair(ffr, ffm) => (quote! { (#ffr)(&self.#i) }, quote! { (#ffm)(&mut self.#i) }),
98                }
99            })
100            .unzip::<_, _, Vec<_>, Vec<_>>();
101
102        let arity = types.len();
103        let impl_fn = if arity > 1 { Ident::new(&format!("state{}_ref", arity), span) } else { Ident::new("state_ref", span) };
104        let impl_fn_mut = if arity > 1 { Ident::new(&format!("state{}_mut", arity), span) } else { Ident::new("state_mut", span) };
105        let impl_trait = if arity > 1 { Ident::new(&format!("AsState{}", types.len()), span) } else { Ident::new("AsState", span) };
106
107        quote! {
108            impl #impl_generics ::fugue_state::#impl_trait<#type_tokens> for #ident #ty_generics #where_clause {
109                fn #impl_fn(&self) -> #refs {
110                    (#(#accessors),*)
111                }
112
113                fn #impl_fn_mut(&mut self) -> #muts {
114                    (#(#accessors_mut),*)
115                }
116            }
117        }
118    })
119    .collect::<TokenStream2>()
120    .into()
121}