authorization_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{Attribute, Ident, Type, Variant};
6
7/// The derive macro #[derive(Authorization)] is used to implement the Authorization trait by default for a struct.\
8/// The trait will not add any authorization to the Api by default.
9#[proc_macro_derive(Authorization, attributes(pagination, filter, sort, range))]
10pub fn authorization_derive(input: TokenStream) -> TokenStream {
11    let ast: syn::DeriveInput = syn::parse(input).unwrap();
12    impl_authorization_derive(&ast)
13}
14
15/// The derive macro #[derive(Oauth2)] is used to implement the Authorization trait for a struct.\
16/// The trait will add OAuth2 authorization to the Api.
17#[proc_macro_derive(Oauth2, attributes(pagination, filter, sort, range))]
18pub fn oauth2_derive(input: TokenStream) -> TokenStream {
19    let ast = syn::parse(input).unwrap();
20    impl_oauth2_derive(&ast)
21}
22
23/// The derive macro #[derive(Basic)] is used to implement the Authorization trait for a struct.\
24/// The trait will add Basic authorization to the Api.
25#[proc_macro_derive(Basic, attributes(pagination, filter, sort, range))]
26pub fn basic_derive(input: TokenStream) -> TokenStream {
27    let ast = syn::parse(input).unwrap();
28    impl_basic_derive(&ast)
29}
30
31/// The derive macro #[derive(Bearer)] is used to implement the Authorization trait for a struct.\
32/// The trait will add Bearer authorization to the Api.
33#[proc_macro_derive(Bearer, attributes(pagination, filter, sort, range))]
34pub fn bearer_derive(input: TokenStream) -> TokenStream {
35    let ast = syn::parse(input).unwrap();
36    impl_bearer_derive(&ast)
37}
38
39/// The derive macro #[derive(ApiKey)] is used to implement the Authorization trait for a struct.\
40/// The trait will add ApiKey authorization to the Api.
41#[proc_macro_derive(ApiKey, attributes(pagination, filter, sort, range))]
42pub fn apikey_derive(input: TokenStream) -> TokenStream {
43    let ast = syn::parse(input).unwrap();
44    impl_apikey_derive(&ast)
45}
46
47/// The derive macro #[derive(OIDC)] is used to implement the Authorization trait for a struct.\
48/// The trait will add OIDC authorization to the Api.
49#[proc_macro_derive(OIDC, attributes(pagination, filter, sort, range))]
50pub fn oidc_derive(input: TokenStream) -> TokenStream {
51    let ast = syn::parse(input).unwrap();
52    impl_oidc_derive(&ast)
53}
54
55/// The derive macro #[derive(Keycloak)] is used to implement the Authorization trait for a struct.\
56/// The trait will add the AuthorizationType authorization to the Api and will use the Keycloak service.
57#[proc_macro_derive(Keycloak, attributes(auth_type, pagination, filter, sort, range))]
58pub fn keycloak_derive(input: TokenStream) -> TokenStream {
59    let ast = syn::parse(input).unwrap();
60    impl_keycloak_derive(&ast)
61}
62
63/// Function to parse generic types for the Authorization implementation
64/// - Pagination
65/// - Filter
66/// - Sort
67/// - Range
68fn get_attribute_types(ast: &syn::DeriveInput) -> (Type, Type, Type, Type) {
69    let pagination = ast
70        .attrs
71        .iter()
72        .find(|attr| attr.path().is_ident("pagination"))
73        .and_then(|attr| {
74            if let Attribute {
75                meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
76                ..
77            } = attr
78            {
79                let name = token.clone().into_iter().next().unwrap().to_string();
80                syn::parse_str::<syn::Type>(&name).ok()
81            } else {
82                None
83            }
84        })
85        .unwrap_or_else(|| syn::parse_str::<syn::Type>("RequestPagination").unwrap());
86    let filter = ast
87        .attrs
88        .iter()
89        .find(|attr| attr.path().is_ident("filter"))
90        .and_then(|attr| {
91            if let Attribute {
92                meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
93                ..
94            } = attr
95            {
96                let name = token.clone().into_iter().next().unwrap().to_string();
97                syn::parse_str::<syn::Type>(&name).ok()
98            } else {
99                None
100            }
101        })
102        .unwrap_or_else(|| syn::parse_str::<syn::Type>("FilterRule").unwrap());
103    let sort = ast
104        .attrs
105        .iter()
106        .find(|attr| attr.path().is_ident("sort"))
107        .and_then(|attr| {
108            if let Attribute {
109                meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
110                ..
111            } = attr
112            {
113                let name = token.clone().into_iter().next().unwrap().to_string();
114                syn::parse_str::<syn::Type>(&name).ok()
115            } else {
116                None
117            }
118        })
119        .unwrap_or_else(|| syn::parse_str::<syn::Type>("SortRule").unwrap());
120    let range = ast
121        .attrs
122        .iter()
123        .find(|attr| attr.path().is_ident("range"))
124        .and_then(|attr| {
125            if let Attribute {
126                meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
127                ..
128            } = attr
129            {
130                let name = token.clone().into_iter().next().unwrap().to_string();
131                syn::parse_str::<syn::Type>(&name).ok()
132            } else {
133                None
134            }
135        })
136        .unwrap_or_else(|| syn::parse_str::<syn::Type>("RangeRule").unwrap());
137    (pagination, filter, sort, range)
138}
139
140/// Only impl the Authorization trait for the struct, with the default implementation.
141fn impl_authorization_derive(ast: &syn::DeriveInput) -> TokenStream {
142    let name = &ast.ident;
143    let (pagination, filter, sort, range) = get_attribute_types(ast);
144    let gen = quote! {
145        impl Authorization<#pagination, #filter, #sort, #range> for #name {}
146    };
147    gen.into()
148}
149
150/// Impl the Authorization trait for the struct, with the OAuth2 implementation.\
151/// The trait accept the pagination, filter, sort and range types as attributes. (Optionals)\
152/// We use the AST to find the attributes (pagination, filter, sort and range) and parse them to the correct type.\
153/// If the attribute is not found, we use the default type.
154fn impl_oauth2_derive(ast: &syn::DeriveInput) -> TokenStream {
155    let name = &ast.ident;
156    let (pagination, filter, sort, range) = get_attribute_types(ast);
157    let token_struct_name = syn::Ident::new(&format!("{name}TokenOAuth2"), name.span());
158    let gen = quote! {
159        #[derive(Deserialize)]
160        struct #token_struct_name {
161            access_token: String,
162        }
163        impl Authorization<#pagination, #filter, #sort, #range> for #name {
164            async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
165                let connector = ApiBuilder::new(url);
166                let client = Client::new();
167
168                let scopes = self
169                    .scopes
170                    .iter()
171                    .fold(String::new(), |acc, scope| format!("{acc} {scope}"));
172                let mut params = HashMap::new();
173                params.insert("grant_type", "client_credentials");
174                params.insert("client_id", &self.client_id);
175                params.insert("client_secret", &self.client_secret);
176                params.insert("scope", &scopes);
177                match client
178                    .post(&self.auth_endpoint)
179                    .header("Content-Type", "application/x-www-form-urlencoded")
180                    .form(&params)
181                    .send()
182                    .await
183                {
184                    Ok(response) => {
185                        match response.status() {
186                            StatusCode::OK
187                            | StatusCode::CREATED
188                            | StatusCode::ACCEPTED
189                            | StatusCode::NO_CONTENT => {}
190                            status => return Err(status.into()),
191                        }
192                        match response.text().await {
193                            Ok(response_text) => {
194                                let token: #token_struct_name =
195                                    serde_json::from_str(&response_text).unwrap();
196                                Ok(connector.oauth2(token.access_token).build())
197                            }
198                            Err(e) => Err(ApiError::ResponseToText(e)),
199                        }
200                    }
201                    Err(e) => Err(ApiError::ReqwestExecute(e)),
202                }
203            }
204        }
205    };
206    gen.into()
207}
208
209/// Impl the Authorization trait for the struct, with the Basic implementation.\
210/// The trait accept the pagination, filter, sort and range types as attributes. (Optionals)\
211/// We use the AST to find the attributes (pagination, filter, sort and range) and parse them to the correct type.\
212/// If the attribute is not found, we use the default type.
213fn impl_basic_derive(ast: &syn::DeriveInput) -> TokenStream {
214    let name = &ast.ident;
215    let (pagination, filter, sort, range) = get_attribute_types(ast);
216    let gen = quote! {
217        impl Authorization<#pagination, #filter, #sort, #range> for #name {
218            async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
219                let connector = ApiBuilder::new(url);
220                let client = Client::new();
221                let encoded_auth = general_purpose::STANDARD_NO_PAD.encode(format!("{}:{}", &self.login, &self.password));
222
223                Ok(connector.basic(encoded_auth).build())
224            }
225        }
226    };
227    gen.into()
228}
229
230/// Impl the Authorization trait for the struct, with the Bearer implementation.\
231/// The trait accept the pagination, filter, sort and range types as attributes. (Optionals)\
232/// We use the AST to find the attributes (pagination, filter, sort and range) and parse them to the correct type.\
233/// If the attribute is not found, we use the default type.
234fn impl_bearer_derive(ast: &syn::DeriveInput) -> TokenStream {
235    let name = &ast.ident;
236    let (pagination, filter, sort, range) = get_attribute_types(ast);
237    let gen = quote! {
238        impl Authorization<#pagination, #filter, #sort, #range> for #name {
239            async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
240                let connector = ApiBuilder::new(url);
241                let client = Client::new();
242
243                Ok(connector.bearer(&self.secret).build())
244            }
245        }
246    };
247    gen.into()
248}
249
250/// Impl the Authorization trait for the struct, with the ApiKey implementation.\
251/// The trait accept the pagination, filter, sort and range types as attributes. (Optionals)\
252/// We use the AST to find the attributes (pagination, filter, sort and range) and parse them to the correct type.\
253/// If the attribute is not found, we use the default type.
254fn impl_apikey_derive(ast: &syn::DeriveInput) -> TokenStream {
255    let name = &ast.ident;
256    let (pagination, filter, sort, range) = get_attribute_types(ast);
257    let gen = quote! {
258        impl Authorization<#pagination, #filter, #sort, #range> for #name {
259            async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
260                let connector = ApiBuilder::new(url);
261                let client = Client::new();
262
263                Ok(connector.apikey(&self.key).build())
264            }
265        }
266    };
267    gen.into()
268}
269
270/// Impl the Authorization trait for the struct, with the OIDC implementation.\
271/// The trait accept the pagination, filter, sort and range types as attributes. (Optionals)\
272/// We use the AST to find the attributes (pagination, filter, sort and range) and parse them to the correct type.\
273/// If the attribute is not found, we use the default type.
274fn impl_oidc_derive(ast: &syn::DeriveInput) -> TokenStream {
275    let name = &ast.ident;
276    let (pagination, filter, sort, range) = get_attribute_types(ast);
277    let token_struct_name = syn::Ident::new(&format!("{name}TokenOIDC"), name.span());
278    let gen = quote! {
279        #[derive(Deserialize)]
280        struct #token_struct_name {
281            access_token: String,
282        }
283        impl Authorization<#pagination, #filter, #sort, #range> for #name {
284            async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
285                let connector = ApiBuilder::new(url);
286                let client = Client::new();
287
288                let scopes = self
289                    .scopes
290                    .iter()
291                    .fold(String::new(), |acc, scope| format!("{acc} {scope}"));
292                let mut params = HashMap::new();
293                params.insert("grant_type", "client_credentials");
294                params.insert("client_id", &self.client_id);
295                params.insert("client_secret", &self.client_secret);
296                params.insert("scope", &scopes);
297                match client
298                    .post(&self.auth_endpoint)
299                    .header("Content-Type", "application/x-www-form-urlencoded")
300                    .form(&params)
301                    .send()
302                    .await
303                {
304                    Ok(response) => {
305                        match response.status() {
306                            StatusCode::OK
307                            | StatusCode::CREATED
308                            | StatusCode::ACCEPTED
309                            | StatusCode::NO_CONTENT => {}
310                            status => return Err(status.into()),
311                        }
312                        match response.text().await {
313                            Ok(response_text) => {
314                                let token: #token_struct_name =
315                                    serde_json::from_str(&response_text).unwrap();
316                                Ok(connector.oidc(token.access_token).build())
317                            }
318                            Err(e) => Err(ApiError::ResponseToText(e)),
319                        }
320                    }
321                    Err(e) => Err(ApiError::ReqwestExecute(e)),
322                }
323            }
324        }
325    };
326    gen.into()
327}
328
329/// Impl the Authorization trait for the struct, with the Keycloak implementation.
330fn impl_keycloak_derive(ast: &syn::DeriveInput) -> TokenStream {
331    let Some(auth_type) = ast
332        .attrs
333        .iter()
334        .find(|attr| attr.path().is_ident("auth_type"))
335        .and_then(|attr| {
336            if let Attribute {
337                meta: syn::Meta::List(syn::MetaList { tokens: token, .. }),
338                ..
339            } = attr
340            {
341                let name = token.clone().into_iter().next().unwrap().to_string();
342                syn::parse_str::<Variant>(&name).ok()
343            } else {
344                None
345            }
346        })
347    else {
348        return quote! {
349            compile_error!(
350                "You need to provide an AuthenticationType to Keycloak!"
351            );
352        }
353        .into();
354    };
355    let name = &ast.ident;
356    let (pagination, filter, sort, range) = get_attribute_types(ast);
357    let auth_variant = auth_type.ident;
358    match auth_variant.to_string().as_str() {
359        "None" | "Basic" | "Bearer" | "ApiKey" | "OAuth2" => keycloak_authorization_impl(
360            auth_variant.to_string(),
361            pagination,
362            filter,
363            sort,
364            range,
365            name,
366        ),
367        _ => quote! {
368            compile_error!(
369                "AuthorizationType must be None, Basic, Bearer, ApiKey or OAuth2 !"
370            );
371        }
372        .into(),
373    }
374}
375
376/// Impl the Authorization trait for the struct, with the Keycloak implementation.
377fn keycloak_authorization_impl(
378    auth_type: String,
379    pagination: Type,
380    filter: Type,
381    sort: Type,
382    range: Type,
383    name: &Ident,
384) -> TokenStream {
385    let token_struct_name = syn::Ident::new(&format!("{name}TokenKeycloak"), name.span());
386    let gen = quote! {
387        #[derive(Deserialize)]
388        struct #token_struct_name {
389            access_token: String,
390        }
391        impl Authorization<#pagination, #filter, #sort, #range> for #name {
392            async fn connect(&self, url: &str) -> Result<Api<#pagination, #filter, #sort, #range>> {
393                let connector = ApiBuilder::new(url);
394                let client = Client::new();
395
396                let auth_header = format!(
397                    "Basic {}",
398                    general_purpose::STANDARD_NO_PAD.encode(format!("{}:{}", &self.client_id, &self.client_secret))
399                );
400                let mut params = HashMap::new();
401                params.insert("grant_type", "password");
402                params.insert("username", &self.user_login);
403                params.insert("password", &self.user_pass);
404                match client
405                    .post(format!(
406                        "{}realms/{}/protocol/openid-connect/token",
407                        self.auth_endpoint, self.realm
408                    ))
409                    .header("Content-Type", "application/x-www-form-urlencoded")
410                    .header("Authorization", auth_header)
411                    .form(&params)
412                    .send()
413                    .await
414                {
415                    Ok(response) => {
416                        log::info!("{:?}", response);
417                        match response.status() {
418                            StatusCode::OK
419                            | StatusCode::CREATED
420                            | StatusCode::ACCEPTED
421                            | StatusCode::NO_CONTENT => {}
422                            status => return Err(status.into()),
423                        }
424                        match response.text().await {
425                            Ok(response_text) => {
426                                let token: #token_struct_name =
427                                    serde_json::from_str(&response_text).unwrap();
428                                Ok(connector.keycloak(match #auth_type {
429                                    "None" => AuthorizationType::None,
430                                    "Basic" => AuthorizationType::Basic(token.access_token),
431                                    "Bearer" => AuthorizationType::Bearer(token.access_token),
432                                    "ApiKey" => AuthorizationType::ApiKey(token.access_token),
433                                    "OAuth2" => AuthorizationType::OAuth2(token.access_token),
434                                    _ => return Err(ApiError::AuthorizationType),
435                                }).build())
436                            }
437                            Err(e) => Err(ApiError::ResponseToText(e)),
438                        }
439                    }
440                    Err(e) => Err(ApiError::ReqwestExecute(e)),
441                }
442            }
443        }
444    };
445    gen.into()
446}