1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{Attribute, Ident, Type, Variant};
6
7#[proc_macro_derive(Authorization, attributes(pagination, filter, sort, range))]
10pub fn authorization_derive(input: TokenStream) -> TokenStream {
11 let ast: syn::DeriveInput = syn::parse(input).unwrap();
12 impl_authorization_derive(&ast)
13}
14
15#[proc_macro_derive(Oauth2, attributes(pagination, filter, sort, range))]
18pub fn oauth2_derive(input: TokenStream) -> TokenStream {
19 let ast = syn::parse(input).unwrap();
20 impl_oauth2_derive(&ast)
21}
22
23#[proc_macro_derive(Basic, attributes(pagination, filter, sort, range))]
26pub fn basic_derive(input: TokenStream) -> TokenStream {
27 let ast = syn::parse(input).unwrap();
28 impl_basic_derive(&ast)
29}
30
31#[proc_macro_derive(Bearer, attributes(pagination, filter, sort, range))]
34pub fn bearer_derive(input: TokenStream) -> TokenStream {
35 let ast = syn::parse(input).unwrap();
36 impl_bearer_derive(&ast)
37}
38
39#[proc_macro_derive(ApiKey, attributes(pagination, filter, sort, range))]
42pub fn apikey_derive(input: TokenStream) -> TokenStream {
43 let ast = syn::parse(input).unwrap();
44 impl_apikey_derive(&ast)
45}
46
47#[proc_macro_derive(OIDC, attributes(pagination, filter, sort, range))]
50pub fn oidc_derive(input: TokenStream) -> TokenStream {
51 let ast = syn::parse(input).unwrap();
52 impl_oidc_derive(&ast)
53}
54
55#[proc_macro_derive(Keycloak, attributes(auth_type, pagination, filter, sort, range))]
58pub fn keycloak_derive(input: TokenStream) -> TokenStream {
59 let ast = syn::parse(input).unwrap();
60 impl_keycloak_derive(&ast)
61}
62
63fn get_attribute_types(ast: &syn::DeriveInput) -> (Type, Type, Type, Type) {
69 let pagination = ast
70 .attrs
71 .iter()
72 .find(|attr| attr.path().is_ident("pagination"))
73 .and_then(|attr| {
74 if let Attribute {
75 meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
76 ..
77 } = attr
78 {
79 let name = token.clone().into_iter().next().unwrap().to_string();
80 syn::parse_str::<syn::Type>(&name).ok()
81 } else {
82 None
83 }
84 })
85 .unwrap_or_else(|| syn::parse_str::<syn::Type>("RequestPagination").unwrap());
86 let filter = ast
87 .attrs
88 .iter()
89 .find(|attr| attr.path().is_ident("filter"))
90 .and_then(|attr| {
91 if let Attribute {
92 meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
93 ..
94 } = attr
95 {
96 let name = token.clone().into_iter().next().unwrap().to_string();
97 syn::parse_str::<syn::Type>(&name).ok()
98 } else {
99 None
100 }
101 })
102 .unwrap_or_else(|| syn::parse_str::<syn::Type>("FilterRule").unwrap());
103 let sort = ast
104 .attrs
105 .iter()
106 .find(|attr| attr.path().is_ident("sort"))
107 .and_then(|attr| {
108 if let Attribute {
109 meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
110 ..
111 } = attr
112 {
113 let name = token.clone().into_iter().next().unwrap().to_string();
114 syn::parse_str::<syn::Type>(&name).ok()
115 } else {
116 None
117 }
118 })
119 .unwrap_or_else(|| syn::parse_str::<syn::Type>("SortRule").unwrap());
120 let range = ast
121 .attrs
122 .iter()
123 .find(|attr| attr.path().is_ident("range"))
124 .and_then(|attr| {
125 if let Attribute {
126 meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
127 ..
128 } = attr
129 {
130 let name = token.clone().into_iter().next().unwrap().to_string();
131 syn::parse_str::<syn::Type>(&name).ok()
132 } else {
133 None
134 }
135 })
136 .unwrap_or_else(|| syn::parse_str::<syn::Type>("RangeRule").unwrap());
137 (pagination, filter, sort, range)
138}
139
140fn impl_authorization_derive(ast: &syn::DeriveInput) -> TokenStream {
142 let name = &ast.ident;
143 let (pagination, filter, sort, range) = get_attribute_types(ast);
144 let gen = quote! {
145 impl Authorization<#pagination, #filter, #sort, #range> for #name {}
146 };
147 gen.into()
148}
149
150fn impl_oauth2_derive(ast: &syn::DeriveInput) -> TokenStream {
155 let name = &ast.ident;
156 let (pagination, filter, sort, range) = get_attribute_types(ast);
157 let token_struct_name = syn::Ident::new(&format!("{name}TokenOAuth2"), name.span());
158 let gen = quote! {
159 #[derive(Deserialize)]
160 struct #token_struct_name {
161 access_token: String,
162 }
163 impl Authorization<#pagination, #filter, #sort, #range> for #name {
164 async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
165 let connector = ApiBuilder::new(url);
166 let client = Client::new();
167
168 let scopes = self
169 .scopes
170 .iter()
171 .fold(String::new(), |acc, scope| format!("{acc} {scope}"));
172 let mut params = HashMap::new();
173 params.insert("grant_type", "client_credentials");
174 params.insert("client_id", &self.client_id);
175 params.insert("client_secret", &self.client_secret);
176 params.insert("scope", &scopes);
177 match client
178 .post(&self.auth_endpoint)
179 .header("Content-Type", "application/x-www-form-urlencoded")
180 .form(¶ms)
181 .send()
182 .await
183 {
184 Ok(response) => {
185 match response.status() {
186 StatusCode::OK
187 | StatusCode::CREATED
188 | StatusCode::ACCEPTED
189 | StatusCode::NO_CONTENT => {}
190 status => return Err(status.into()),
191 }
192 match response.text().await {
193 Ok(response_text) => {
194 let token: #token_struct_name =
195 serde_json::from_str(&response_text).unwrap();
196 Ok(connector.oauth2(token.access_token).build())
197 }
198 Err(e) => Err(ApiError::ResponseToText(e)),
199 }
200 }
201 Err(e) => Err(ApiError::ReqwestExecute(e)),
202 }
203 }
204 }
205 };
206 gen.into()
207}
208
209fn impl_basic_derive(ast: &syn::DeriveInput) -> TokenStream {
214 let name = &ast.ident;
215 let (pagination, filter, sort, range) = get_attribute_types(ast);
216 let gen = quote! {
217 impl Authorization<#pagination, #filter, #sort, #range> for #name {
218 async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
219 let connector = ApiBuilder::new(url);
220 let client = Client::new();
221 let encoded_auth = general_purpose::STANDARD_NO_PAD.encode(format!("{}:{}", &self.login, &self.password));
222
223 Ok(connector.basic(encoded_auth).build())
224 }
225 }
226 };
227 gen.into()
228}
229
230fn impl_bearer_derive(ast: &syn::DeriveInput) -> TokenStream {
235 let name = &ast.ident;
236 let (pagination, filter, sort, range) = get_attribute_types(ast);
237 let gen = quote! {
238 impl Authorization<#pagination, #filter, #sort, #range> for #name {
239 async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
240 let connector = ApiBuilder::new(url);
241 let client = Client::new();
242
243 Ok(connector.bearer(&self.secret).build())
244 }
245 }
246 };
247 gen.into()
248}
249
250fn impl_apikey_derive(ast: &syn::DeriveInput) -> TokenStream {
255 let name = &ast.ident;
256 let (pagination, filter, sort, range) = get_attribute_types(ast);
257 let gen = quote! {
258 impl Authorization<#pagination, #filter, #sort, #range> for #name {
259 async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
260 let connector = ApiBuilder::new(url);
261 let client = Client::new();
262
263 Ok(connector.apikey(&self.key).build())
264 }
265 }
266 };
267 gen.into()
268}
269
270fn impl_oidc_derive(ast: &syn::DeriveInput) -> TokenStream {
275 let name = &ast.ident;
276 let (pagination, filter, sort, range) = get_attribute_types(ast);
277 let token_struct_name = syn::Ident::new(&format!("{name}TokenOIDC"), name.span());
278 let gen = quote! {
279 #[derive(Deserialize)]
280 struct #token_struct_name {
281 access_token: String,
282 }
283 impl Authorization<#pagination, #filter, #sort, #range> for #name {
284 async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
285 let connector = ApiBuilder::new(url);
286 let client = Client::new();
287
288 let scopes = self
289 .scopes
290 .iter()
291 .fold(String::new(), |acc, scope| format!("{acc} {scope}"));
292 let mut params = HashMap::new();
293 params.insert("grant_type", "client_credentials");
294 params.insert("client_id", &self.client_id);
295 params.insert("client_secret", &self.client_secret);
296 params.insert("scope", &scopes);
297 match client
298 .post(&self.auth_endpoint)
299 .header("Content-Type", "application/x-www-form-urlencoded")
300 .form(¶ms)
301 .send()
302 .await
303 {
304 Ok(response) => {
305 match response.status() {
306 StatusCode::OK
307 | StatusCode::CREATED
308 | StatusCode::ACCEPTED
309 | StatusCode::NO_CONTENT => {}
310 status => return Err(status.into()),
311 }
312 match response.text().await {
313 Ok(response_text) => {
314 let token: #token_struct_name =
315 serde_json::from_str(&response_text).unwrap();
316 Ok(connector.oidc(token.access_token).build())
317 }
318 Err(e) => Err(ApiError::ResponseToText(e)),
319 }
320 }
321 Err(e) => Err(ApiError::ReqwestExecute(e)),
322 }
323 }
324 }
325 };
326 gen.into()
327}
328
329fn impl_keycloak_derive(ast: &syn::DeriveInput) -> TokenStream {
331 let Some(auth_type) = ast
332 .attrs
333 .iter()
334 .find(|attr| attr.path().is_ident("auth_type"))
335 .and_then(|attr| {
336 if let Attribute {
337 meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
338 ..
339 } = attr
340 {
341 let name = token.clone().into_iter().next().unwrap().to_string();
342 syn::parse_str::<Variant>(&name).ok()
343 } else {
344 None
345 }
346 })
347 else {
348 return quote! {
349 compile_error!(
350 "You need to provide an AuthenticationType to Keycloak!"
351 );
352 }
353 .into();
354 };
355 let name = &ast.ident;
356 let (pagination, filter, sort, range) = get_attribute_types(ast);
357 let auth_variant = auth_type.ident;
358 match auth_variant.to_string().as_str() {
359 "None" | "Basic" | "Bearer" | "ApiKey" | "OAuth2" => keycloak_authorization_impl(
360 auth_variant.to_string(),
361 pagination,
362 filter,
363 sort,
364 range,
365 name,
366 ),
367 _ => quote! {
368 compile_error!(
369 "AuthorizationType must be None, Basic, Bearer, ApiKey or OAuth2 !"
370 );
371 }
372 .into(),
373 }
374}
375
376fn keycloak_authorization_impl(
378 auth_type: String,
379 pagination: Type,
380 filter: Type,
381 sort: Type,
382 range: Type,
383 name: &Ident,
384) -> TokenStream {
385 let token_struct_name = syn::Ident::new(&format!("{name}TokenKeycloak"), name.span());
386 let gen = quote! {
387 #[derive(Deserialize)]
388 struct #token_struct_name {
389 access_token: String,
390 }
391 impl Authorization<#pagination, #filter, #sort, #range> for #name {
392 async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
393 let connector = ApiBuilder::new(url);
394 let client = Client::new();
395
396 let auth_header = format!(
397 "Basic {}",
398 general_purpose::STANDARD_NO_PAD.encode(format!("{}:{}", &self.client_id, &self.client_secret))
399 );
400 let mut params = HashMap::new();
401 params.insert("grant_type", "password");
402 params.insert("username", &self.user_login);
403 params.insert("password", &self.user_pass);
404 match client
405 .post(format!(
406 "{}realms/{}/protocol/openid-connect/token",
407 self.auth_endpoint, self.realm
408 ))
409 .header("Content-Type", "application/x-www-form-urlencoded")
410 .header("Authorization", auth_header)
411 .form(¶ms)
412 .send()
413 .await
414 {
415 Ok(response) => {
416 log::info!("{:?}", response);
417 match response.status() {
418 StatusCode::OK
419 | StatusCode::CREATED
420 | StatusCode::ACCEPTED
421 | StatusCode::NO_CONTENT => {}
422 status => return Err(status.into()),
423 }
424 match response.text().await {
425 Ok(response_text) => {
426 let token: #token_struct_name =
427 serde_json::from_str(&response_text).unwrap();
428 Ok(connector.keycloak(match #auth_type {
429 "None" => AuthorizationType::None,
430 "Basic" => AuthorizationType::Basic(token.access_token),
431 "Bearer" => AuthorizationType::Bearer(token.access_token),
432 "ApiKey" => AuthorizationType::ApiKey(token.access_token),
433 "OAuth2" => AuthorizationType::OAuth2(token.access_token),
434 _ => return Err(ApiError::AuthorizationType),
435 }).build())
436 }
437 Err(e) => Err(ApiError::ResponseToText(e)),
438 }
439 }
440 Err(e) => Err(ApiError::ReqwestExecute(e)),
441 }
442 }
443 }
444 };
445 gen.into()
446}