1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, spanned::Spanned, Error, ItemFn, LitStr};
4
5#[proc_macro_attribute]
6pub fn global_transaction(
7 attr: TokenStream,
8 func: TokenStream,
9) -> TokenStream {
10 let transaction_name = parse_macro_input!(attr as LitStr);
11 let mut input_fn = parse_macro_input!(func as ItemFn);
12
13 if input_fn.sig.asyncness.is_none() {
14 return Error::new(
15 input_fn.sig.span(),
16 "global_transaction can only be applied to async functions",
17 ).to_compile_error().into();
18 }
19
20 let _is_result = if let syn::ReturnType::Type(_, ref ty) = input_fn.sig.output {
21 if let syn::Type::Path(type_path) = &**ty {
22 type_path
23 .path
24 .segments
25 .last()
26 .map(|seg| seg.ident == "Result")
27 .unwrap_or(false)
28 } else {
29 false
30 }
31 } else {
32 false
33 };
34
35 let original_block = &input_fn.block;
36 let new_block = quote! {
37 {
38 use rseata::core::TransactionManager;
39 use rseata::FutureExt;
40 use std::panic::AssertUnwindSafe;
41 use rseata::RSEATA_TM;
42 use std::sync::Arc;
43 use rseata::RSEATA_CLIENT_SESSION;
44 use rseata::core::{ClientSession};
45 let session = Arc::new(ClientSession::new(String::from(#transaction_name)));
46 let session_clone = session.clone();
47
48 let result = RSEATA_CLIENT_SESSION.scope(
49 session,
50 AssertUnwindSafe(async {
51 { #original_block }
52 })
53 .catch_unwind()
54 .map(|res| res.unwrap_or_else(|_| Err(anyhow::anyhow!("Panic occurred in transaction scope")))),
55 ).await;
56
57 let xid = session_clone.get_xid();
58 if let Some (xid) = xid {
59 match result {
60 Ok(data) => {
61 RSEATA_TM.commit(xid.clone()).await?;
62 Ok(data)
63 }
64 Err(err) => {
65 RSEATA_TM.rollback(xid.clone()).await?;
66 Err(err)
67 }
68 }
69 }else {
70 result
71 }
72 }
73 };
74 input_fn.block = syn::parse2(new_block).unwrap();
75 TokenStream::from(quote! { #input_fn })
76}