function_compose_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2
3use std::fmt::Formatter;
4use std::{fmt::Display, ops::Deref};
5
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::parse::ParseStream;
8use syn::{parse::Parse, Expr, FnArg, ItemFn, ReturnType, Token, Type};
9
10use crate::OptionalRetry::SomeRetry;
11
/// Returns the name of the generic type parameter numbered `index`, e.g. `T3`
/// for `index == 3`.
fn generate_return_type_param(index: u8) -> String {
    let mut name = String::from("T");
    name.push_str(&index.to_string());
    name
}
15
16mod keyword {
17    syn::custom_keyword!(retry);
18}
19
/// Returns the comma-terminated list of generic type parameters `T1,T2,…,Tcount,`.
///
/// The trailing comma lets callers splice the list directly in front of further
/// generic parameters (e.g. `"<'a, {list} E1>"`). Returns an empty string when
/// `count` is 0.
fn generate_generics_parameters(count: u8) -> String {
    (1..=count).map(|i| format!("T{i},")).collect()
}
27
28#[proc_macro_attribute]
29pub fn retry(_attr: TokenStream, _item: TokenStream) -> TokenStream {
30    panic!()
31}
32
/// Retry configuration parsed from `#[composeable(retry = <expr>)]`.
struct Retry {
    /// The expression written after `retry =`; it is interpolated verbatim as the
    /// strategy argument of the generated retry call.
    strategy: Expr,
}
36
/// Parameters of the annotated function, rendered by its `ToTokens` impl as the
/// argument list of a call that forwards each parameter by name.
struct FunctionArgs<'a> {
    /// The function's parameters, borrowed from its signature.
    args: Vec<&'a FnArg>,
}
40
/// Parameters of the annotated function, rendered by its `ToTokens` impl as a
/// parameter list in which every parameter binding is declared `mut`.
struct FunctionMutArgs<'a> {
    /// The function's parameters, borrowed from its signature.
    args: Vec<&'a FnArg>,
}
44
45impl<'a> ToTokens for FunctionArgs<'a> {
46    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
47        self.args.iter().for_each(|arg| {
48            match *arg {
49                FnArg::Receiver(_) => {}
50                FnArg::Typed(t) => {
51                    let ident = &t.pat;
52                    let ty = t.ty.as_ref();
53                    let token_stream = match ty {
54                        Type::Reference(reference) => {
55                            if reference.mutability.is_some() {
56                                quote! {
57                                   #[allow(unused)] &mut #ident,
58                                }
59                            } else {
60                                quote! {
61                                    #[allow(unused)] & #ident,
62                                }
63                            }
64                        }
65                        _ => {
66                            quote! {
67                                 #[allow(unused)] #ident,
68                            }
69                        }
70                    };
71
72                    tokens.append_all(token_stream.into_iter());
73                    /*let mut toeknStream: proc_macro::TokenStream = tokeStream.into();
74                    toeknStream.extend(tokens.into_iter())*/
75                }
76            }
77        });
78    }
79}
80
81impl<'a> ToTokens for FunctionMutArgs<'a> {
82    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
83        self.args.iter().for_each(|arg| {
84            let token_stream = quote! {
85                       #[allow(unused)] mut #arg,
86            };
87            tokens.append_all(token_stream.into_iter());
88        });
89    }
90}
91
/// Parsed arguments of `#[composeable(...)]`: either a retry strategy given as
/// `retry = <expr>`, or nothing.
enum OptionalRetry {
    /// `retry = <expr>` was supplied.
    SomeRetry(Retry),
    /// No retry strategy was supplied.
    NoRetry,
}
96
97impl Display for Retry {
98    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99        write!(f, "{}", "test")
100    }
101}
102
103impl Parse for OptionalRetry {
104    fn parse(input: ParseStream) -> syn::Result<Self> {
105        let lookahead = input.lookahead1();
106        if lookahead.peek(keyword::retry) {
107            let _ = input.parse::<keyword::retry>();
108            let _ = input.parse::<Token![=]>();
109            let expr: Expr = input.parse()?;
110            println!("{}", expr.to_token_stream());
111            Ok(OptionalRetry::SomeRetry(Retry { strategy: expr }))
112        } else {
113            Ok(OptionalRetry::NoRetry)
114        }
115    }
116}
117
118
/// Prepends the `fn_composer__` namespace prefix to `ident`, so generated helper
/// names do not clash with user-defined items.
fn generate_ident_with_prefix(ident: &str) -> String {
    let mut prefixed = String::from("fn_composer__");
    prefixed.push_str(ident);
    prefixed
}
122
123
124#[proc_macro_attribute]
125pub fn composeable(attr: TokenStream, item: TokenStream) -> TokenStream {
126    
127
128    let token_stream_clone = item.clone();
129    let item_fn: ItemFn = syn::parse_macro_input!(token_stream_clone);
130
131    let fn_gen = item_fn.sig.generics;
132    let mut async_fn = item_fn.sig.asyncness.is_some();
133    let input_args = item_fn.sig.inputs;
134    let arg_tokens: Vec<_> = input_args.iter().collect();
135    let mut_arg_tokens: Vec<_> = input_args.iter().collect();
136    let arg_length = input_args.len();
137    let fn_ident = &item_fn.sig.ident;
138    let fn_name = item_fn.sig.ident.to_string();
139    let fn_return_type = &item_fn.sig.output;
140
141    let return_type_without_token = match fn_return_type {
142        ReturnType::Default => None,
143        ReturnType::Type(_, return_type) => Some(return_type),
144    };
145
146    let retry = syn::parse_macro_input!(attr as OptionalRetry);
147
148    if !async_fn {
149        match fn_return_type {
150            syn::ReturnType::Default => {}
151            syn::ReturnType::Type(_, t) => {
152                let x = t.deref();
153                async_fn = x.to_token_stream().to_string().starts_with("BoxFuture");
154            }
155        }
156    }
157    let lifted_fn_name = "lifted_fn_".to_owned() + &fn_name;
158    let prefixed_lifted_fn_name = &generate_ident_with_prefix(&lifted_fn_name);
159    //let lift_retry_fn_name = &generate_ident_with_prefix(&("retry_".to_owned() + &lifted_fn_name));
160    let lift_fn_ident = syn::Ident::new(prefixed_lifted_fn_name, proc_macro2::Span::call_site());
161    let is_retry_fn_name = &generate_ident_with_prefix(&("is_retryable_".to_owned() + &fn_name));
162    let is_retry_fn_ident = syn::Ident::new(&is_retry_fn_name, proc_macro2::Span::call_site());
163    let retry_fn_ident = syn::Ident::new(
164        &generate_ident_with_prefix(&("retry_".to_owned() + &fn_name)),
165        proc_macro2::Span::call_site(),
166    );
167
168    let async_fn_name = &generate_ident_with_prefix(&("is_async_".to_owned() + fn_name.deref()));
169    let async_fn_ident = syn::Ident::new(
170        async_fn_name,
171        proc_macro2::Span::call_site(),
172    );
173    let (
174        _return_type,
175        _underlying_lift_fn_name,
176        fun_gen,
177        return_type_ident,
178        ret_gen,
179        underlying_lift_fn_name_ident,
180    ) = /*if (asyncFn)*/ {
181        let return_type = if async_fn {
182            "BoxedAsyncFn".to_owned() + arg_length.to_string().as_str()
183        } else {
184            "BoxedFn".to_owned() + arg_length.to_string().as_str()
185        };
186        let underlying_lift_fn_name = if async_fn {
187            "lift_async_fn".to_owned() + arg_length.to_string().as_str()
188        } else {
189            "lift_sync_fn".to_owned() + arg_length.to_string().as_str()
190        };
191        let gen_type_params = generate_generics_parameters((arg_length + 1) as u8);
192        let fun_arg_params = generate_generics_parameters((arg_length) as u8);
193        let return_type_param = generate_return_type_param((arg_length + 1) as u8);
194        let fun_gen = if async_fn {
195            let gen = format!("<'a, {gen_type_params} E1, F:Fn({fun_arg_params})->BoxFuture<'a,Result<{return_type_param}, E1>> + 'a + Send +Sync>", );            
196            syn::parse_str::<syn::Generics>(
197                gen.as_str()
198            ).ok()
199                .unwrap()
200        } else {
201            let gen  =format!("<'a, {gen_type_params} E1, F:Fn({fun_arg_params})->Result<{return_type_param}, E1> + Send +Sync + 'a>");            
202            syn::parse_str::<syn::Generics>(
203                gen.as_str()                
204            ).ok().unwrap()
205        };
206
207        let return_type_ident = syn::Ident::new(return_type.as_str(), proc_macro2::Span::call_site());
208        let underlying_lift_fn_name_ident =
209            syn::Ident::new(underlying_lift_fn_name.as_str(), proc_macro2::Span::call_site());
210        let ret_gen = syn::parse_str::<syn::Generics>(format!("<'a,{gen_type_params} E1>").as_str()).ok().unwrap();
211
212        (
213            return_type,
214            underlying_lift_fn_name,
215            fun_gen,
216            return_type_ident,
217            ret_gen,
218            underlying_lift_fn_name_ident,
219        )
220    };
221
222    match retry {
223        OptionalRetry::NoRetry => {
224            let function_mut_args = FunctionMutArgs {
225                args: mut_arg_tokens,
226            };
227            let tokens: proc_macro2::TokenStream = quote! {
228                use function_compose::*;
229
230                pub fn #lift_fn_ident #fun_gen(f: F)  -> #return_type_ident #ret_gen{
231                    #underlying_lift_fn_name_ident(f)
232                }
233
234                pub fn #async_fn_ident ()  -> bool{
235                    #async_fn
236                }
237
238                 pub fn #is_retry_fn_ident ()  -> bool{
239                         false
240                    }
241
242                /**
243                * It is only added to keep the compiler happy for non retryable functions
244                */
245                pub fn #retry_fn_ident #fn_gen ( #function_mut_args)  #fn_return_type {
246                    panic!("Function not to be called");
247                }
248            };
249            let mut token_stream: proc_macro::TokenStream = tokens.into();
250            token_stream.extend(item.into_iter());
251            //println!("{}", toekn_stream.to_string());
252            token_stream
253        }
254        SomeRetry(strategy) => {
255            let function_args = FunctionArgs { args: arg_tokens };
256
257            let function_mut_args = FunctionMutArgs {
258                args: mut_arg_tokens,
259            };
260
261            let mutable_args: Vec<_> = filter_mutable_args(&function_args);
262
263            let mutex_tokens: Vec<_> = convert_to_create_mutex_tokens(&mutable_args);
264
265            let mutex_unlock_tokens: Vec<_> = convert_to_mutex_unlock_tokens(mutable_args);
266
267            let deref_mut_tokens: Vec<_> = convert_to_deref_tokens(&function_args);
268
269            let strategy_expr = strategy.strategy;
270            let retry_tokens: proc_macro2::TokenStream = if async_fn {
271                quote! {
272
273                    pub fn #retry_fn_ident #fn_gen(#function_mut_args)  #fn_return_type {
274                        use function_compose::*;
275                        use retry::*;
276                        use tokio_retry::Retry as AsyncRetry;
277                        use tokio::sync::Mutex;
278                        use std::ops::{Deref, DerefMut};
279                        async{
280                            #( #mutex_tokens )*
281                            let result = AsyncRetry::spawn(#strategy_expr, || async{
282                                #( #mutex_unlock_tokens )*;
283                                let r = #fn_ident(#( #deref_mut_tokens )*);
284                                //OperationResult::from()
285                                r.await
286                            });
287
288                            let result = match result.await{
289                                    Ok(result) => Ok(result),
290                                    Err(e) => Err(e)
291                            };
292                            result
293                        }.boxed()
294                    }
295                }
296            } else {
297                quote! {
298
299                    pub fn #retry_fn_ident #fn_gen (#function_mut_args)  #fn_return_type {
300                        use function_compose::*;
301                        use retry::*;
302
303                        let result = retry(#strategy_expr, ||{
304                            let r:#return_type_without_token = #fn_ident(#function_args).into();
305                            r
306                        });
307                        match result{
308                            Ok(result) => Ok(result),
309                            Err(e) => Err(e.error)
310                        }
311                    }
312                }
313            };
314
315            /*println!("#############################");
316            println!("{}", retry_tokens);
317            println!("#############################");*/
318
319            let tokens: proc_macro2::TokenStream = quote! {
320
321                use function_compose::*;
322                pub fn #lift_fn_ident #fun_gen(f: F)  -> #return_type_ident #ret_gen{
323                    //#lift_retry_fn_ident(#retryFnIdent)
324                    #underlying_lift_fn_name_ident(f)
325                }
326
327                /*pub fn #lift_retry_fn_ident #fun_gen(f: F)  -> #return_type_ident #ret_gen{
328                    #underlying_lift_fn_name_ident(#retryFnIdent)
329                }*/
330
331                 pub fn #is_retry_fn_ident ()  -> bool{
332                     true
333                }
334
335
336                pub fn #async_fn_ident ()  -> bool{
337                    #async_fn
338                }
339            };
340            let retry_token_stream: TokenStream = retry_tokens.into();
341
342            let mut token_stream: proc_macro::TokenStream = tokens.into();
343
344            token_stream.extend(item.into_iter());
345            token_stream.extend(retry_token_stream.into_iter());
346            /*println!("{}", toekn_stream.to_string());*/
347            token_stream
348        }
349    }
350}
351
352fn filter_mutable_args<'a>(function_args: &'a FunctionArgs) -> Vec<&'a &'a FnArg> {
353    function_args
354        .args
355        .iter()
356        .filter(|fn_arg| {
357            match fn_arg {
358                FnArg::Receiver(_) => return false,
359                FnArg::Typed(pat) => {
360                    let ty = pat.ty.deref();
361                    match ty {
362                        Type::Reference(ty_ref) => return ty_ref.mutability.is_some(),
363
364                        _ => {
365                            return false;
366                        }
367                    }
368                }
369            }
370        })
371        .collect()
372}
373
374fn convert_to_create_mutex_tokens(mutable_args: &Vec<&&FnArg>) -> Vec<proc_macro2::TokenStream> {
375    mutable_args
376        .iter()
377        .map(|i| match i {
378            FnArg::Receiver(_pat_type) => {
379                panic!();
380            }
381            FnArg::Typed(pat_type) => {
382                let pat = &pat_type.pat;
383                quote! {
384                    let mut #pat =Mutex::new(#pat);
385                }
386            }
387        })
388        .collect()
389}
390
391fn convert_to_mutex_unlock_tokens(mutable_args: Vec<&&FnArg>) -> Vec<proc_macro2::TokenStream> {
392    mutable_args
393        .iter()
394        .map(|i| match i {
395            FnArg::Receiver(_pat_type) => {
396                panic!();
397            }
398            FnArg::Typed(pat_type) => {
399                let pat = &pat_type.pat;
400                quote! {
401                    let mut #pat = #pat.lock().await;
402
403                }
404            }
405        })
406        .collect()
407}
408
409fn convert_to_deref_tokens(function_args: &FunctionArgs) -> Vec<proc_macro2::TokenStream> {
410    function_args
411        .args
412        .iter()
413        .map(|i| {
414            match i {
415                FnArg::Receiver(_pat_type) => {
416                    return quote! {};
417                }
418                FnArg::Typed(pat_type) => {
419                    let pat = &pat_type.pat;
420                    let ty = pat_type.ty.deref();
421                    match ty {
422                        Type::Reference(ty_ref) => {
423                            if ty_ref.mutability.is_some() {
424                                return quote! {
425                                    #pat.deref_mut(),
426                                };
427                            } else {
428                                return quote! {
429                                    #pat,
430                                };
431                            }
432                        }
433
434                        _ => {
435                            return quote! {
436                                #pat,
437                            };
438                        }
439                    }
440                    /*quote!{
441                        let mut #pat = #pat.lock().await;
442
443                    }*/
444                }
445            }
446        })
447        .collect()
448}