Skip to main content

rseata_micro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, spanned::Spanned, Error, ItemFn, LitStr};
4
5#[proc_macro_attribute]
6pub fn global_transaction(
7    attr: TokenStream,
8    func: TokenStream,
9) -> TokenStream {
10    let transaction_name = parse_macro_input!(attr as LitStr);
11    let mut input_fn = parse_macro_input!(func as ItemFn);
12
13    if input_fn.sig.asyncness.is_none() {
14        return Error::new(
15            input_fn.sig.span(),
16            "global_transaction can only be applied to async functions",
17        ).to_compile_error().into();
18    }
19
20    let _is_result = if let syn::ReturnType::Type(_, ref ty) = input_fn.sig.output {
21        if let syn::Type::Path(type_path) = &**ty {
22            type_path
23                .path
24                .segments
25                .last()
26                .map(|seg| seg.ident == "Result")
27                .unwrap_or(false)
28        } else {
29            false
30        }
31    } else {
32        false
33    };
34
35    let original_block = &input_fn.block;
36    let new_block = quote! {
37       {
38           use rseata::core::TransactionManager;
39           use rseata::FutureExt;
40           use std::panic::AssertUnwindSafe;
41           use rseata::RSEATA_TM;
42           use std::sync::Arc;
43           use rseata::RSEATA_CLIENT_SESSION;
44           use rseata::core::{ClientSession};
45            let session = Arc::new(ClientSession::new(String::from(#transaction_name)));
46            let session_clone = session.clone();
47           
48            let result = RSEATA_CLIENT_SESSION.scope(
49                session,
50                AssertUnwindSafe(async {
51                    { #original_block }
52                })
53                .catch_unwind()
54                .map(|res| res.unwrap_or_else(|_| Err(anyhow::anyhow!("Panic occurred in transaction scope")))),
55            ).await;
56           
57            let xid = session_clone.get_xid();
58            if let Some (xid) = xid {
59                match result {
60                    Ok(data) => {
61                        RSEATA_TM.commit(xid.clone()).await?;
62                        Ok(data)
63                    }
64                    Err(err) => {
65                        RSEATA_TM.rollback(xid.clone()).await?;
66                        Err(err)
67                    }
68                }
69            }else {
70                result
71            }
72      }
73    };
74    input_fn.block = syn::parse2(new_block).unwrap();
75    TokenStream::from(quote! { #input_fn })
76}