1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, quote_spanned};
4use std::convert::TryFrom;
5use syn::spanned::Spanned;
6use syn::{
7 parse_macro_input, parse_quote, Attribute, FnArg, Ident, ImplItemMethod, ItemTrait, LitStr,
8 Pat, PatIdent, PatType, Signature, TraitItem, TraitItemMethod,
9};
10
11#[proc_macro_attribute]
12pub fn retroqwest(_args: TokenStream, input: TokenStream) -> TokenStream {
13 let item = parse_macro_input!(input as ItemTrait);
14
15 expand(item).unwrap_or_else(to_compile_errors).into()
16}
17
18fn to_compile_errors(errors: syn::Error) -> proc_macro2::TokenStream {
19 let compile_errors = errors.to_compile_error();
20 compile_errors
21}
22
23struct HttpMethodAttribute {
24 method: Ident,
25 _response: Option<Ident>,
26 uri: LitStr,
27}
28
29impl TryFrom<Attribute> for HttpMethodAttribute {
30 type Error = syn::Error;
31 fn try_from(att: Attribute) -> Result<Self, Self::Error> {
32 let mut segments = att.path.segments.iter();
33
34 segments.next().expect("the http segment has to exist");
35
36 let uri: LitStr = att.parse_args()?;
37
38 let path_seg = segments
39 .next()
40 .ok_or(syn::Error::new(att.span(), "http attribute missing method"))?;
41
42 Ok(Self {
43 method: path_seg.ident.clone(),
44 _response: None,
45 uri,
46 })
47 }
48}
49
50fn get_att(attrs: &mut Vec<Attribute>, name: &'static str) -> Option<Attribute> {
51 attrs
52 .iter()
53 .enumerate()
54 .find_map(move |(i, a)| {
55 a.path
56 .segments
57 .first()
58 .filter(|p| p.ident == name)
59 .map(move |_| i)
60 })
61 .map(|i| attrs.remove(i))
62}
63
64enum HttpArg {
65 JsonBody {
66 arg: Ident,
67 span: Span,
68 },
69 Uri {
70 arg: Ident,
71 span: Span,
72 },
73 Query {
74 name: LitStr,
75 arg: Ident,
76 span: Span,
77 },
78}
79
80impl HttpArg {
81 fn parse(arg: &mut FnArg) -> Option<Self> {
82 match arg {
83 FnArg::Typed(PatType { attrs, pat, .. }) => {
84 if let Pat::Ident(PatIdent { ident, .. }) = pat.as_ref() {
85 if let Some(_json_att) = get_att(attrs, "json") {
86 Some(HttpArg::JsonBody {
87 arg: ident.clone(),
88 span: ident.span(),
89 })
90 } else if let Some(_query_att) = get_att(attrs, "query") {
91 Some(HttpArg::Query {
92 name: LitStr::new(ident.to_string().as_str(), ident.span()),
93 arg: ident.clone(),
94 span: ident.span(),
95 })
96 } else {
97 Some(HttpArg::Uri {
98 arg: ident.clone(),
99 span: ident.span(),
100 })
101 }
102 } else {
103 None
104 }
105 }
106 _ => None,
107 }
108 }
109}
110
111fn build_method(
112 attrs: &mut Vec<Attribute>,
113 sig: &mut Signature,
114) -> Result<ImplItemMethod, syn::Error> {
115 let attr = get_att(attrs, "http")
116 .ok_or(syn::Error::new(sig.span(), "Missing http method attribute"))?;
117 let att = HttpMethodAttribute::try_from(attr)?;
118
119 let args = sig
120 .inputs
121 .iter_mut()
122 .filter_map(HttpArg::parse)
123 .collect::<Vec<_>>();
124
125 let uri_args = args.iter().filter_map(|a| match a {
126 HttpArg::Uri { arg, span } => Some(quote_spanned!(*span=> #arg = #arg)),
127 _ => None,
128 });
129
130 let query_args = args
131 .iter()
132 .filter_map(|a| match a {
133 HttpArg::Query { name, arg, span } => Some(quote_spanned!(*span=> (#name, format!("{}", #arg)))),
134 _ => None,
135 })
136 .collect::<Vec<_>>();
137
138 let query = if query_args.is_empty() {
139 None
140 } else {
141 Some(quote! { .query(&[#(#query_args, )*]) })
142 };
143
144 let body_args = args.iter().filter_map(|a| match a {
145 HttpArg::JsonBody { arg, span } => Some(quote_spanned!(*span=> .json(#arg))),
146 _ => None,
147 });
148
149 let uri = att.uri;
150 let method = att.method;
151
152 let uri = quote_spanned!(uri.span()=>concat!("{}", #uri));
153
154 Ok(parse_quote! {
155 #(#attrs)*
156 #sig {
157 Ok(self.client.#method(format!(#uri, self.endpoint#(, #uri_args)*))
158 #query
159 #(#body_args)*
160 .send().await.map_err(retroqwest::RetroqwestError::RequestError)?
161 .error_for_status().map_err(|source| retroqwest::RetroqwestError::ResponseError {
162 status: source.status().unwrap(),
163 source
164 })?
165 .json().await.map_err(retroqwest::RetroqwestError::JsonParse)?)
166 }
167 })
168}
169
170fn expand(mut def: ItemTrait) -> Result<proc_macro2::TokenStream, syn::Error> {
171 let trait_name = &def.ident;
172 let name = Ident::new(&format!("{}Client", trait_name), def.ident.span());
173 let vis = &def.vis;
174
175 let mut methods: Vec<ImplItemMethod> = vec![];
176
177 for member in &mut def.items {
178 match member {
179 TraitItem::Method(TraitItemMethod {
180 attrs,
181 sig,
182 default,
183 ..
184 }) => {
185 if default.is_some() {
186 return Err(syn::Error::new(
187 default.as_ref().unwrap().span(),
188 "retroquest trait methods cannot have defaults",
189 ));
190 }
191
192 methods.push(build_method(attrs, sig)?)
193 }
194 a => return Err(syn::Error::new(a.span(), "Only trait methods are supported on a retroqwest trait")),
195 }
196 }
197
198 let client = quote! {
199 #[derive(Clone, Debug)]
200 #vis struct #name {
201 endpoint: String,
202 client: retroqwest::reqwest::Client,
203 }
204
205 #[async_trait::async_trait]
206 impl #trait_name for #name {
207 #(#methods)*
208 }
209
210 impl #name {
211 fn from_builder<T: Into<String>>(
212 base_url: T,
213 client_builder: retroqwest::reqwest::ClientBuilder)
214 -> Result<Self, retroqwest::RetroqwestError> {
215 Ok(Self {
216 endpoint: base_url.into().trim_end_matches('/').to_string(),
217 client: client_builder.build().map_err(retroqwest::RetroqwestError::FailedToBuildClient)?
218 })
219 }
220 }
221 };
222
223 def.attrs
224 .push(parse_quote!(#[retroqwest::async_trait::async_trait]));
225
226 Ok(quote! {
227 #def
228
229 #client
230 })
231}