osauth_derive/
lib.rs

1use std::fmt;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use proc_macro2::{Span, TokenStream as TS2};
6use quote::{quote, ToTokens};
7
8#[proc_macro_derive(
9    PaginatedResource,
10    attributes(resource_id, collection_name, flat_collection)
11)]
12pub fn paginated_resource_macro_derive(input: TokenStream) -> TokenStream {
13    let input = syn::parse_macro_input!(input as syn::DeriveInput);
14
15    let class_name = &input.ident;
16    let vis = &input.vis;
17    let maybe_collection_name = match get_collection_name(&input) {
18        Ok(name) => name,
19        Err(err) => return err.into_compile_error().into(),
20    };
21    let (id_name, id_type) = match get_id_field(&input) {
22        Ok(tpl) => tpl,
23        Err(err) => return err.into_compile_error().into(),
24    };
25
26    if let Some(collection_name) = maybe_collection_name {
27        let collection_ident = syn::Ident::new(&collection_name, Span::call_site());
28        let collection_class_name = syn::Ident::new(
29            &format!("{}DerivedOSResourceCollection", class_name),
30            Span::call_site(),
31        );
32
33        quote! {
34            #[derive(Debug, ::serde::Deserialize)]
35            #[allow(missing_docs, unused)]
36            #vis struct #collection_class_name {
37                #collection_ident: Vec<#class_name>,
38            }
39
40            #[allow(missing_docs, unused)]
41            impl ::osauth::PaginatedResource for #class_name {
42                type Id = #id_type;
43                type Root = #collection_class_name;
44                fn resource_id(&self) -> Self::Id {
45                    self.#id_name.clone()
46                }
47            }
48
49            #[allow(missing_docs, unused)]
50            impl From<#collection_class_name> for Vec<#class_name> {
51                fn from(value: #collection_class_name) -> Vec<#class_name> {
52                    value.#collection_ident
53                }
54            }
55        }
56    } else {
57        quote! {
58            #[allow(missing_docs, unused)]
59            impl ::osauth::PaginatedResource for #class_name {
60                type Id = #id_type;
61                type Root = Vec<#class_name>;
62                fn resource_id(&self) -> Self::Id {
63                    self.#id_name.clone()
64                }
65            }
66        }
67    }
68    .into()
69}
70
71fn get_attr<'a>(attrs: &'a [syn::Attribute], attr: &str) -> Option<&'a syn::Attribute> {
72    attrs.iter().find(|x| x.path.is_ident(attr))
73}
74
75fn get_id_field(input: &syn::DeriveInput) -> syn::Result<(&syn::Ident, &syn::Type)> {
76    let mut default_id = None;
77    if let syn::Data::Struct(ref st) = input.data {
78        if let syn::Fields::Named(ref fs) = st.fields {
79            for field in &fs.named {
80                if get_attr(&field.attrs, "resource_id").is_some() {
81                    return Ok((
82                        field.ident.as_ref().expect("no ident for resource_id"),
83                        &field.ty,
84                    ));
85                }
86
87                if let Some(id) = field.ident.as_ref() {
88                    if id == "id" {
89                        default_id = Some((id, &field.ty));
90                    }
91                }
92            }
93        } else {
94            return Err(syn::Error::new_spanned(
95                input,
96                "only named fields are supported for derive(PaginatedResource)",
97            ));
98        }
99    } else {
100        return Err(syn::Error::new_spanned(
101            input,
102            "only structs are supported for derive(PaginatedResource)",
103        ));
104    }
105
106    if let Some(id) = default_id {
107        Ok(id)
108    } else {
109        Err(syn::Error::new_spanned(input, "#[resource_id] missing"))
110    }
111}
112
113fn get_collection_name(input: &syn::DeriveInput) -> syn::Result<Option<String>> {
114    let mut flat = false;
115    let mut maybe_name = None;
116    for attr in &input.attrs {
117        match attr.parse_meta() {
118            Ok(syn::Meta::NameValue(nv)) if nv.path.is_ident("collection_name") => {
119                if flat {
120                    return Err(syn::Error::new_spanned(
121                        attr,
122                        "collection_name and flat_collection cannot be used together",
123                    ));
124                }
125                match nv.lit {
126                    syn::Lit::Str(s) => maybe_name = Some(s.value()),
127                    _ => {
128                        return Err(syn::Error::new_spanned(
129                            attr,
130                            "collection_name must be a string",
131                        ))
132                    }
133                }
134            }
135            Ok(syn::Meta::Path(p)) if p.is_ident("flat_collection") => {
136                if maybe_name.is_some() {
137                    return Err(syn::Error::new_spanned(
138                        attr,
139                        "collection_name and flat_collection cannot be used together",
140                    ));
141                }
142                flat = true;
143            }
144            _ => {}
145        }
146    }
147
148    Ok(if flat {
149        None
150    } else {
151        maybe_name.or_else(|| {
152            let ident = input.ident.to_string().to_case(Case::Snake);
153            Some(
154                if ident.chars().last().expect("empty collection_name") == 's' {
155                    format!("{}es", ident)
156                } else {
157                    format!("{}s", ident)
158                },
159            )
160        })
161    })
162}
163
164fn fail<S, M>(span: S, message: M) -> TokenStream
165where
166    S: ToTokens,
167    M: fmt::Display,
168{
169    syn::Error::new_spanned(span, message)
170        .into_compile_error()
171        .into()
172}
173
174#[proc_macro_derive(QueryItem, attributes(query_item))]
175pub fn query_item_macro_derive(input: TokenStream) -> TokenStream {
176    let input = syn::parse_macro_input!(input as syn::DeriveInput);
177
178    let class_name = &input.ident;
179    let fragments = match query_item_fragments(
180        class_name,
181        match input.data {
182            syn::Data::Enum(e) => e,
183            _ => {
184                return fail(input, "derive(QueryItem) only works on enums");
185            }
186        },
187    ) {
188        Ok(f) => f,
189        Err(e) => return e.into_compile_error().into(),
190    };
191
192    quote! {
193        impl ::osauth::QueryItem for #class_name {
194            fn query_item(&self) -> ::std::result::Result<(&str, ::std::borrow::Cow<str>), ::osauth::Error> {
195                Ok(match self {
196                    #(#fragments),*
197                })
198            }
199        }
200    }.into()
201}
202
203fn query_item_fragments(class_name: &syn::Ident, input: syn::DataEnum) -> syn::Result<Vec<TS2>> {
204    let mut result = Vec::with_capacity(input.variants.len());
205    for var in input.variants {
206        let name = if let Some(attr) = get_attr(&var.attrs, "query_item") {
207            match attr.parse_meta()? {
208                syn::Meta::NameValue(nv) => match nv.lit {
209                    syn::Lit::Str(s) => s.value(),
210                    _ => {
211                        return Err(syn::Error::new_spanned(
212                            attr,
213                            "query_item value must be a string",
214                        ));
215                    }
216                },
217                _ => {
218                    return Err(syn::Error::new_spanned(
219                        attr,
220                        "query_item must have a value",
221                    ));
222                }
223            }
224        } else {
225            var.ident.to_string().to_case(Case::Snake)
226        };
227        match var.fields {
228            syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
229                let field = fields.unnamed.into_iter().next().unwrap();
230                result.push(query_item_fragment(class_name, var.ident, &name, field));
231            }
232            _ => {
233                return Err(syn::Error::new_spanned(
234                    var,
235                    "each variant must have exactly one unnamed type",
236                ));
237            }
238        }
239    }
240    Ok(result)
241}
242
243fn query_item_fragment(
244    class_name: &syn::Ident,
245    ident: syn::Ident,
246    name: &str,
247    field: syn::Field,
248) -> TS2 {
249    let ty = field.ty;
250    match ty {
251        syn::Type::Path(tp) if tp.qself.is_none() && tp.path.is_ident("String") => {
252            quote! {
253                #class_name::#ident(var) => (#name, ::std::borrow::Cow::Borrowed(var.as_str()))
254            }
255        }
256        syn::Type::Path(tp) if tp.qself.is_none() && tp.path.is_ident("bool") => {
257            quote! {
258                #class_name::#ident(var) => {
259                    let value = if *var { "true" } else { "false" };
260                    (#name, ::std::borrow::Cow::Borrowed(value))
261                }
262            }
263        }
264        _ => quote! {
265            #class_name::#ident(var) => (#name, var.to_string().into())
266        },
267    }
268}