Skip to main content

go_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, Type, parse_macro_input};
4
5mod select_parse;
6use select_parse::parse_select;
7
8/// Mark the main function to use the GoRust runtime
9///
10/// # Example
11/// ```rust
12/// use gorust::{runtime, go, make_chan};
13///
14/// #[runtime]
15/// fn main() {
16///     go(|| {
17///         debug!("Hello from goroutine!");
18///     });
19///     
20///     let ch = make_chan!(i32, 10);
21///     go(move || {
22///         ch.send(42).unwrap();
23///     });
24///     
25///     debug!("Received: {}", ch.recv().unwrap());
26/// }
27/// ```
28#[proc_macro_attribute]
29pub fn runtime(_args: TokenStream, input: TokenStream) -> TokenStream {
30    let input_fn = parse_macro_input!(input as ItemFn);
31    let _fn_name = &input_fn.sig.ident;
32    let fn_block = &input_fn.block;
33    let fn_attrs = &input_fn.attrs;
34    let fn_vis = &input_fn.vis;
35    let fn_sig = &input_fn.sig;
36
37    // 生成包装后的 main 函数
38    let expanded = quote! {
39        #(#fn_attrs)*
40        #fn_vis #fn_sig {
41            ::gorust::Runtime::init();
42
43            let original_hook = std::panic::take_hook();
44            std::panic::set_hook(Box::new(move |panic_info| {
45                println!("[Goroutine panic] {}", panic_info);
46                original_hook(panic_info);
47            }));
48
49            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
50                #fn_block
51            }));
52
53            ::gorust::Runtime::wait_for_all();
54            ::gorust::Runtime::wait_and_shutdown();
55            ::gorust::Runtime::shutdown();
56
57            match result {
58                Ok(ret) => ret,
59                Err(err) => std::panic::resume_unwind(err),
60            }
61        }
62    };
63
64    TokenStream::from(expanded)
65}
66
67/// make_chan! macro - Create a channel
68#[proc_macro]
69pub fn make_chan(input: TokenStream) -> TokenStream {
70    let parsed_input = input.to_string();
71    let parts: Vec<&str> = parsed_input.trim_end_matches(')').split(',').collect();
72
73    if parts.len() == 1 {
74        // 无缓冲 channel
75        let ty = parts[0].trim();
76        let ty_parsed: Type = syn::parse_str(ty).expect("Invalid type in make_chan!");
77        let expanded = quote! {
78            ::gorust::Channel::<#ty_parsed>::new(0)
79        };
80        TokenStream::from(expanded)
81    } else {
82        // 有缓冲 channel
83        let ty = parts[0].trim();
84        let ty_parsed: Type = syn::parse_str(ty).expect("Invalid type in make_chan!");
85        let size = parts[1].trim();
86        // 尝试解析容量为表达式
87        let size_expr: proc_macro2::TokenStream =
88            size.parse().expect("Invalid capacity in make_chan!");
89        let expanded = quote! {
90            ::gorust::Channel::<#ty_parsed>::new(#size_expr)
91        };
92        TokenStream::from(expanded)
93    }
94}
95
96/// select! macro - Multiplex channel operations
97#[proc_macro]
98pub fn select(input: TokenStream) -> TokenStream {
99    let input_str = input.to_string();
100    match parse_select(input_str) {
101        Ok(expanded) => TokenStream::from(expanded),
102        Err(err) => {
103            let error = quote! {
104                compile_error!(#err)
105            };
106            TokenStream::from(error)
107        }
108    }
109}