rsexp_derive/
lib.rs

// This derive crate converts between user-defined struct/enum types and the `Sexp` type.
// It might be more efficient to write a direct serialization/deserialization deriver,
// either by hand or via serde.
//
// TODO: support sexp.option, default values, allowing extra fields, etc.
6extern crate proc_macro;
7
8use proc_macro::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11    parse_quote, DataEnum, DataUnion, DeriveInput, FieldsNamed, FieldsUnnamed, GenericParam,
12};
13
14#[proc_macro_derive(SexpOf)]
15pub fn sexp_of_derive(input: TokenStream) -> TokenStream {
16    let ast = syn::parse(input).unwrap();
17    impl_sexp_of(&ast)
18}
19
20fn impl_sexp_of(ast: &DeriveInput) -> TokenStream {
21    let DeriveInput { ident, data, generics, .. } = ast;
22    let mut generics = generics.clone();
23    for param in &mut generics.params {
24        if let GenericParam::Type(type_param) = param {
25            type_param.bounds.push(parse_quote!(SexpOf))
26        }
27    }
28    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
29    let impl_fn = match data {
30        syn::Data::Struct(s) => match &s.fields {
31            syn::Fields::Named(FieldsNamed { named, .. }) => {
32                let fields = named.iter().map(|field| {
33                    let name = field.ident.as_ref().unwrap();
34                    let name_str = name.to_string();
35                    quote! { rsexp::list(&[rsexp::atom(#name_str.as_bytes()), self.#name.sexp_of()]) }
36                });
37                quote! {rsexp::list(&[#(#fields),*])}
38            }
39            syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
40                let num_fields = unnamed.len();
41                let fields = (0..num_fields).map(|index| {
42                    let index = syn::Index::from(index);
43                    quote! { self.#index.sexp_of() }
44                });
45                quote! {rsexp::list(&[#(#fields),*])}
46            }
47            syn::Fields::Unit => {
48                unimplemented!()
49            }
50        },
51        syn::Data::Enum(DataEnum { variants, .. }) => {
52            let cases = variants.iter().map(|variant| {
53                let variant_ident = &variant.ident;
54                let variant_bytes = syn::LitByteStr::new(variant_ident.to_string().as_bytes(), variant_ident.span());
55                let cstor = quote! { rsexp::atom(#variant_bytes) };
56                let (pattern, sexp) = match &variant.fields {
57                    syn::Fields::Named(FieldsNamed { named, .. }) => {
58                        let args = named.iter().map(|field| field.ident.as_ref().unwrap());
59                        let fields = named.iter().map(|field| {
60                            let name = field.ident.as_ref().unwrap();
61                            let name_str = name.to_string();
62                            quote! { rsexp::list(&[rsexp::atom(#name_str.as_bytes()), #name.sexp_of()]) }
63                        });
64                        let sexp =
65                            if variant.fields.is_empty() {
66                                quote! { #cstor }
67                            } else {
68                                quote! { rsexp::list(&[#cstor, #(#fields),*]) }
69                            };
70                        (quote! { { #(#args),* } }, sexp)
71                    }
72                    syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
73                        let num_fields = unnamed.len();
74                        let args = (0..num_fields).map(|index| format_ident!("arg{}", index));
75                        let fields = args.clone().map(|arg| quote! { #arg.sexp_of() });
76                        let sexp =
77                            if num_fields == 0 {
78                                quote! { #cstor }
79                            } else {
80                                quote! { rsexp::list(&[#cstor, #(#fields),*]) }
81                            };
82                        (quote! { (#(#args),*) }, sexp)
83                    }
84                    syn::Fields::Unit => (quote! {}, quote! { #cstor }),
85                };
86                quote! {
87                    #ident::#variant_ident #pattern => { #sexp }
88                }
89            });
90            quote! {
91                match self {
92                    #(#cases)*
93                }
94            }
95        }
96        syn::Data::Union(DataUnion { union_token, .. }) => {
97            return syn::Error::new_spanned(&union_token, "union is not supported")
98                .to_compile_error()
99                .into();
100        }
101    };
102
103    let output = quote! {
104        impl #impl_generics rsexp::SexpOf for #ident #ty_generics #where_clause {
105            fn sexp_of(&self) -> rsexp::Sexp {
106                #impl_fn
107            }
108        }
109    };
110
111    output.into()
112}
113
114#[proc_macro_derive(OfSexp)]
115pub fn of_sexp_derive(input: TokenStream) -> TokenStream {
116    let ast = syn::parse(input).unwrap();
117    impl_of_sexp(&ast)
118}
119
120// This assumes that __fields has been defined as a &[Sexp]
121fn impl_named_struct_of_sexp(
122    fields_named: &syn::FieldsNamed,
123    output_ident: proc_macro2::TokenStream,
124) -> proc_macro2::TokenStream {
125    let named = &fields_named.named;
126    let ident_str = output_ident.to_string();
127    let fields = named.iter().map(|field| field.ident.as_ref().unwrap());
128    let mk_fields = named.iter().map(|field| {
129        let name = field.ident.as_ref().unwrap();
130        let name_str = name.to_string();
131        quote! {
132            let #name = match __map.remove(#name_str.as_bytes()) {
133                Some(sexp) => rsexp::OfSexp::of_sexp(sexp)?,
134                None => return Err(rsexp::IntoSexpError::MissingFieldsInStruct {
135                    type_: #ident_str,
136                    field: #name_str,
137                })
138            };
139        }
140    });
141    quote! {
142        let mut __map: std::collections::HashMap<&[u8], &rsexp::Sexp> = rsexp::Sexp::extract_map(__fields, #ident_str)?;
143        #(#mk_fields)*
144        if !__map.is_empty() {
145            let mut extra_fields: Vec<_> = __map.into_keys().map(|x| String::from_utf8_lossy(x).to_string()).collect();
146            extra_fields.sort();
147            return Err(rsexp::IntoSexpError::ExtraFieldsInStruct {
148                type_: #ident_str,
149                extra_fields,
150            })
151        }
152        Ok(#output_ident { #(#fields),* })
153    }
154}
155
156fn impl_unnamed_struct_of_sexp(
157    fields_unnamed: &syn::FieldsUnnamed,
158    output_ident: proc_macro2::TokenStream,
159) -> proc_macro2::TokenStream {
160    let unnamed = &fields_unnamed.unnamed;
161    let ident_str = output_ident.to_string();
162
163    let num_fields = unnamed.len();
164    let fields = (0..num_fields).map(|index| format_ident!("__field{}", index));
165    let fields_ = fields.clone();
166    let fields_list = quote! { #(rsexp::OfSexp::of_sexp(#fields)?),*};
167    quote! {
168        match __fields {
169            [#(#fields_,)*] => Ok(#output_ident(#fields_list)),
170            l => Err(rsexp::IntoSexpError::ListLengthMismatch {
171                type_: #ident_str,
172                expected_len: #num_fields,
173                list_len: l.len(),
174            }),
175        }
176    }
177}
178fn impl_of_sexp(ast: &DeriveInput) -> TokenStream {
179    let DeriveInput { ident, data, generics, .. } = ast;
180    let ident_str = ident.to_string();
181    let mut generics = generics.clone();
182    for param in &mut generics.params {
183        if let GenericParam::Type(type_param) = param {
184            type_param.bounds.push(parse_quote!(OfSexp))
185        }
186    }
187    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
188
189    let of_sexp_fn = match data {
190        syn::Data::Struct(s) => match &s.fields {
191            syn::Fields::Named(f) => {
192                let result = impl_named_struct_of_sexp(&f, quote! {#ident});
193                quote! {
194                    let __fields = __s.extract_list(#ident_str)?;
195                    #result
196                }
197            }
198            syn::Fields::Unnamed(f) => {
199                let result = impl_unnamed_struct_of_sexp(&f, quote! {#ident});
200                quote! {
201                    let __fields = __s.extract_list(#ident_str)?;
202                    #result
203                }
204            }
205            syn::Fields::Unit => quote! {#ident},
206        },
207        syn::Data::Enum(DataEnum { variants, .. }) => {
208            let cases = variants.iter().map(|variant| {
209                let variant_ident = &variant.ident;
210                let variant_bytes = syn::LitByteStr::new(
211                    variant_ident.to_string().as_bytes(),
212                    variant_ident.span(),
213                );
214                let branch = match &variant.fields {
215                    syn::Fields::Named(f) => {
216                        impl_named_struct_of_sexp(&f, quote! {#ident::#variant_ident})
217                    }
218                    syn::Fields::Unnamed(f) => {
219                        impl_unnamed_struct_of_sexp(&f, quote! {#ident::#variant_ident})
220                    }
221                    syn::Fields::Unit => quote! {#ident::#variant_ident},
222                };
223                quote! {
224                    (#variant_bytes, __fields) => {
225                        #branch
226                    }
227                }
228            });
229            quote! {
230            match __s.extract_enum(#ident_str)? {
231                #(#cases)*
232                (ctor, _) =>
233                    Err(rsexp::IntoSexpError::UnknownConstructorForEnum {
234                        type_: #ident_str,
235                        constructor: String::from_utf8_lossy(ctor).to_string(),
236                    }),
237                }
238            }
239        }
240        syn::Data::Union(DataUnion { union_token, .. }) => {
241            return syn::Error::new_spanned(&union_token, "union is not supported")
242                .to_compile_error()
243                .into();
244        }
245    };
246
247    let output = quote! {
248        impl #impl_generics rsexp::OfSexp for #ident #ty_generics #where_clause {
249            fn of_sexp(__s: &rsexp::Sexp) -> std::result::Result<Self, rsexp::IntoSexpError> {
250                #of_sexp_fn
251            }
252        }
253    };
254
255    output.into()
256}