1use std::fmt;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use proc_macro2::{Span, TokenStream as TS2};
6use quote::{quote, ToTokens};
7
8#[proc_macro_derive(
9 PaginatedResource,
10 attributes(resource_id, collection_name, flat_collection)
11)]
12pub fn paginated_resource_macro_derive(input: TokenStream) -> TokenStream {
13 let input = syn::parse_macro_input!(input as syn::DeriveInput);
14
15 let class_name = &input.ident;
16 let vis = &input.vis;
17 let maybe_collection_name = match get_collection_name(&input) {
18 Ok(name) => name,
19 Err(err) => return err.into_compile_error().into(),
20 };
21 let (id_name, id_type) = match get_id_field(&input) {
22 Ok(tpl) => tpl,
23 Err(err) => return err.into_compile_error().into(),
24 };
25
26 if let Some(collection_name) = maybe_collection_name {
27 let collection_ident = syn::Ident::new(&collection_name, Span::call_site());
28 let collection_class_name = syn::Ident::new(
29 &format!("{}DerivedOSResourceCollection", class_name),
30 Span::call_site(),
31 );
32
33 quote! {
34 #[derive(Debug, ::serde::Deserialize)]
35 #[allow(missing_docs, unused)]
36 #vis struct #collection_class_name {
37 #collection_ident: Vec<#class_name>,
38 }
39
40 #[allow(missing_docs, unused)]
41 impl ::osauth::PaginatedResource for #class_name {
42 type Id = #id_type;
43 type Root = #collection_class_name;
44 fn resource_id(&self) -> Self::Id {
45 self.#id_name.clone()
46 }
47 }
48
49 #[allow(missing_docs, unused)]
50 impl From<#collection_class_name> for Vec<#class_name> {
51 fn from(value: #collection_class_name) -> Vec<#class_name> {
52 value.#collection_ident
53 }
54 }
55 }
56 } else {
57 quote! {
58 #[allow(missing_docs, unused)]
59 impl ::osauth::PaginatedResource for #class_name {
60 type Id = #id_type;
61 type Root = Vec<#class_name>;
62 fn resource_id(&self) -> Self::Id {
63 self.#id_name.clone()
64 }
65 }
66 }
67 }
68 .into()
69}
70
71fn get_attr<'a>(attrs: &'a [syn::Attribute], attr: &str) -> Option<&'a syn::Attribute> {
72 attrs.iter().find(|x| x.path.is_ident(attr))
73}
74
75fn get_id_field(input: &syn::DeriveInput) -> syn::Result<(&syn::Ident, &syn::Type)> {
76 let mut default_id = None;
77 if let syn::Data::Struct(ref st) = input.data {
78 if let syn::Fields::Named(ref fs) = st.fields {
79 for field in &fs.named {
80 if get_attr(&field.attrs, "resource_id").is_some() {
81 return Ok((
82 field.ident.as_ref().expect("no ident for resource_id"),
83 &field.ty,
84 ));
85 }
86
87 if let Some(id) = field.ident.as_ref() {
88 if id == "id" {
89 default_id = Some((id, &field.ty));
90 }
91 }
92 }
93 } else {
94 return Err(syn::Error::new_spanned(
95 input,
96 "only named fields are supported for derive(PaginatedResource)",
97 ));
98 }
99 } else {
100 return Err(syn::Error::new_spanned(
101 input,
102 "only structs are supported for derive(PaginatedResource)",
103 ));
104 }
105
106 if let Some(id) = default_id {
107 Ok(id)
108 } else {
109 Err(syn::Error::new_spanned(input, "#[resource_id] missing"))
110 }
111}
112
113fn get_collection_name(input: &syn::DeriveInput) -> syn::Result<Option<String>> {
114 let mut flat = false;
115 let mut maybe_name = None;
116 for attr in &input.attrs {
117 match attr.parse_meta() {
118 Ok(syn::Meta::NameValue(nv)) if nv.path.is_ident("collection_name") => {
119 if flat {
120 return Err(syn::Error::new_spanned(
121 attr,
122 "collection_name and flat_collection cannot be used together",
123 ));
124 }
125 match nv.lit {
126 syn::Lit::Str(s) => maybe_name = Some(s.value()),
127 _ => {
128 return Err(syn::Error::new_spanned(
129 attr,
130 "collection_name must be a string",
131 ))
132 }
133 }
134 }
135 Ok(syn::Meta::Path(p)) if p.is_ident("flat_collection") => {
136 if maybe_name.is_some() {
137 return Err(syn::Error::new_spanned(
138 attr,
139 "collection_name and flat_collection cannot be used together",
140 ));
141 }
142 flat = true;
143 }
144 _ => {}
145 }
146 }
147
148 Ok(if flat {
149 None
150 } else {
151 maybe_name.or_else(|| {
152 let ident = input.ident.to_string().to_case(Case::Snake);
153 Some(
154 if ident.chars().last().expect("empty collection_name") == 's' {
155 format!("{}es", ident)
156 } else {
157 format!("{}s", ident)
158 },
159 )
160 })
161 })
162}
163
164fn fail<S, M>(span: S, message: M) -> TokenStream
165where
166 S: ToTokens,
167 M: fmt::Display,
168{
169 syn::Error::new_spanned(span, message)
170 .into_compile_error()
171 .into()
172}
173
174#[proc_macro_derive(QueryItem, attributes(query_item))]
175pub fn query_item_macro_derive(input: TokenStream) -> TokenStream {
176 let input = syn::parse_macro_input!(input as syn::DeriveInput);
177
178 let class_name = &input.ident;
179 let fragments = match query_item_fragments(
180 class_name,
181 match input.data {
182 syn::Data::Enum(e) => e,
183 _ => {
184 return fail(input, "derive(QueryItem) only works on enums");
185 }
186 },
187 ) {
188 Ok(f) => f,
189 Err(e) => return e.into_compile_error().into(),
190 };
191
192 quote! {
193 impl ::osauth::QueryItem for #class_name {
194 fn query_item(&self) -> ::std::result::Result<(&str, ::std::borrow::Cow<str>), ::osauth::Error> {
195 Ok(match self {
196 #(#fragments),*
197 })
198 }
199 }
200 }.into()
201}
202
203fn query_item_fragments(class_name: &syn::Ident, input: syn::DataEnum) -> syn::Result<Vec<TS2>> {
204 let mut result = Vec::with_capacity(input.variants.len());
205 for var in input.variants {
206 let name = if let Some(attr) = get_attr(&var.attrs, "query_item") {
207 match attr.parse_meta()? {
208 syn::Meta::NameValue(nv) => match nv.lit {
209 syn::Lit::Str(s) => s.value(),
210 _ => {
211 return Err(syn::Error::new_spanned(
212 attr,
213 "query_item value must be a string",
214 ));
215 }
216 },
217 _ => {
218 return Err(syn::Error::new_spanned(
219 attr,
220 "query_item must have a value",
221 ));
222 }
223 }
224 } else {
225 var.ident.to_string().to_case(Case::Snake)
226 };
227 match var.fields {
228 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
229 let field = fields.unnamed.into_iter().next().unwrap();
230 result.push(query_item_fragment(class_name, var.ident, &name, field));
231 }
232 _ => {
233 return Err(syn::Error::new_spanned(
234 var,
235 "each variant must have exactly one unnamed type",
236 ));
237 }
238 }
239 }
240 Ok(result)
241}
242
243fn query_item_fragment(
244 class_name: &syn::Ident,
245 ident: syn::Ident,
246 name: &str,
247 field: syn::Field,
248) -> TS2 {
249 let ty = field.ty;
250 match ty {
251 syn::Type::Path(tp) if tp.qself.is_none() && tp.path.is_ident("String") => {
252 quote! {
253 #class_name::#ident(var) => (#name, ::std::borrow::Cow::Borrowed(var.as_str()))
254 }
255 }
256 syn::Type::Path(tp) if tp.qself.is_none() && tp.path.is_ident("bool") => {
257 quote! {
258 #class_name::#ident(var) => {
259 let value = if *var { "true" } else { "false" };
260 (#name, ::std::borrow::Cow::Borrowed(value))
261 }
262 }
263 }
264 _ => quote! {
265 #class_name::#ident(var) => (#name, var.to_string().into())
266 },
267 }
268}