retroqwest_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, quote_spanned};
4use std::convert::TryFrom;
5use syn::spanned::Spanned;
6use syn::{
7    parse_macro_input, parse_quote, Attribute, FnArg, Ident, ImplItemMethod, ItemTrait, LitStr,
8    Pat, PatIdent, PatType, Signature, TraitItem, TraitItemMethod,
9};
10
11#[proc_macro_attribute]
12pub fn retroqwest(_args: TokenStream, input: TokenStream) -> TokenStream {
13    let item = parse_macro_input!(input as ItemTrait);
14
15    expand(item).unwrap_or_else(to_compile_errors).into()
16}
17
18fn to_compile_errors(errors: syn::Error) -> proc_macro2::TokenStream {
19    let compile_errors = errors.to_compile_error();
20    compile_errors
21}
22
23struct HttpMethodAttribute {
24    method: Ident,
25    _response: Option<Ident>,
26    uri: LitStr,
27}
28
29impl TryFrom<Attribute> for HttpMethodAttribute {
30    type Error = syn::Error;
31    fn try_from(att: Attribute) -> Result<Self, Self::Error> {
32        let mut segments = att.path.segments.iter();
33
34        segments.next().expect("the http segment has to exist");
35
36        let uri: LitStr = att.parse_args()?;
37
38        let path_seg = segments
39            .next()
40            .ok_or(syn::Error::new(att.span(), "http attribute missing method"))?;
41
42        Ok(Self {
43            method: path_seg.ident.clone(),
44            _response: None,
45            uri,
46        })
47    }
48}
49
50fn get_att(attrs: &mut Vec<Attribute>, name: &'static str) -> Option<Attribute> {
51    attrs
52        .iter()
53        .enumerate()
54        .find_map(move |(i, a)| {
55            a.path
56                .segments
57                .first()
58                .filter(|p| p.ident == name)
59                .map(move |_| i)
60        })
61        .map(|i| attrs.remove(i))
62}
63
64enum HttpArg {
65    JsonBody {
66        arg: Ident,
67        span: Span,
68    },
69    Uri {
70        arg: Ident,
71        span: Span,
72    },
73    Query {
74        name: LitStr,
75        arg: Ident,
76        span: Span,
77    },
78}
79
80impl HttpArg {
81    fn parse(arg: &mut FnArg) -> Option<Self> {
82        match arg {
83            FnArg::Typed(PatType { attrs, pat, .. }) => {
84                if let Pat::Ident(PatIdent { ident, .. }) = pat.as_ref() {
85                    if let Some(_json_att) = get_att(attrs, "json") {
86                        Some(HttpArg::JsonBody {
87                            arg: ident.clone(),
88                            span: ident.span(),
89                        })
90                    } else if let Some(_query_att) = get_att(attrs, "query") {
91                        Some(HttpArg::Query {
92                            name: LitStr::new(ident.to_string().as_str(), ident.span()),
93                            arg: ident.clone(),
94                            span: ident.span(),
95                        })
96                    } else {
97                        Some(HttpArg::Uri {
98                            arg: ident.clone(),
99                            span: ident.span(),
100                        })
101                    }
102                } else {
103                    None
104                }
105            }
106            _ => None,
107        }
108    }
109}
110
111fn build_method(
112    attrs: &mut Vec<Attribute>,
113    sig: &mut Signature,
114) -> Result<ImplItemMethod, syn::Error> {
115    let attr = get_att(attrs, "http")
116        .ok_or(syn::Error::new(sig.span(), "Missing http method attribute"))?;
117    let att = HttpMethodAttribute::try_from(attr)?;
118
119    let args = sig
120        .inputs
121        .iter_mut()
122        .filter_map(HttpArg::parse)
123        .collect::<Vec<_>>();
124
125    let uri_args = args.iter().filter_map(|a| match a {
126        HttpArg::Uri { arg, span } => Some(quote_spanned!(*span=> #arg = #arg)),
127        _ => None,
128    });
129
130    let query_args = args
131        .iter()
132        .filter_map(|a| match a {
133            HttpArg::Query { name, arg, span } => Some(quote_spanned!(*span=> (#name, format!("{}", #arg)))),
134            _ => None,
135        })
136        .collect::<Vec<_>>();
137
138    let query = if query_args.is_empty() {
139        None
140    } else {
141        Some(quote! { .query(&[#(#query_args, )*]) })
142    };
143
144    let body_args = args.iter().filter_map(|a| match a {
145        HttpArg::JsonBody { arg, span } => Some(quote_spanned!(*span=> .json(#arg))),
146        _ => None,
147    });
148
149    let uri = att.uri;
150    let method = att.method;
151
152    let uri = quote_spanned!(uri.span()=>concat!("{}", #uri));
153
154    Ok(parse_quote! {
155      #(#attrs)*
156      #sig {
157       Ok(self.client.#method(format!(#uri, self.endpoint#(, #uri_args)*))
158          #query
159          #(#body_args)*
160          .send().await.map_err(retroqwest::RetroqwestError::RequestError)?
161          .error_for_status().map_err(|source| retroqwest::RetroqwestError::ResponseError {
162            status: source.status().unwrap(),
163            source
164          })?
165          .json().await.map_err(retroqwest::RetroqwestError::JsonParse)?)
166      }
167    })
168}
169
170fn expand(mut def: ItemTrait) -> Result<proc_macro2::TokenStream, syn::Error> {
171    let trait_name = &def.ident;
172    let name = Ident::new(&format!("{}Client", trait_name), def.ident.span());
173    let vis = &def.vis;
174
175    let mut methods: Vec<ImplItemMethod> = vec![];
176
177    for member in &mut def.items {
178        match member {
179            TraitItem::Method(TraitItemMethod {
180                attrs,
181                sig,
182                default,
183                ..
184            }) => {
185                if default.is_some() {
186                    return Err(syn::Error::new(
187                        default.as_ref().unwrap().span(),
188                        "retroquest trait methods cannot have defaults",
189                    ));
190                }
191
192                methods.push(build_method(attrs, sig)?)
193            }
194            a => return Err(syn::Error::new(a.span(), "Only trait methods are supported on a retroqwest trait")),
195        }
196    }
197
198    let client = quote! {
199      #[derive(Clone, Debug)]
200      #vis struct #name {
201        endpoint: String,
202        client: retroqwest::reqwest::Client,
203      }
204
205      #[async_trait::async_trait]
206      impl #trait_name for #name {
207          #(#methods)*
208      }
209
210      impl #name {
211        fn from_builder<T: Into<String>>(
212          base_url: T,
213          client_builder: retroqwest::reqwest::ClientBuilder)
214        -> Result<Self, retroqwest::RetroqwestError>  {
215          Ok(Self {
216            endpoint: base_url.into().trim_end_matches('/').to_string(),
217            client: client_builder.build().map_err(retroqwest::RetroqwestError::FailedToBuildClient)?
218          })
219        }
220      }
221    };
222
223    def.attrs
224        .push(parse_quote!(#[retroqwest::async_trait::async_trait]));
225
226    Ok(quote! {
227      #def
228
229      #client
230    })
231}