fracture_macros/
lib.rs

1mod select;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{Attribute, Expr, ExprLit, ItemFn, Lit, Meta, ReturnType, Token, parse_macro_input, punctuated::Punctuated, token::Comma};
6
7struct MacroArgs {
8    duration: proc_macro2::TokenStream
9}
10
11impl Default for MacroArgs {
12    fn default() -> Self {
13        Self {
14            duration: quote! { ::std::time::Duration::from_secs(60) }
15        }
16    }
17}
18
19impl MacroArgs {
20    fn parse(args: Punctuated<Meta, Comma>) -> Result<Self, syn::Error> {
21        let mut macro_args = Self::default();
22
23        for meta in args {
24            match meta {
25                Meta::NameValue(nv) => {
26                    if nv.path.is_ident("duration") {
27                        if let Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. })= &nv.value {
28                            macro_args.duration = parse_duration(&lit_str.value())?;
29                        }
30                        else {
31                            return Err(syn::Error::new_spanned(&nv.value, "Expected duration as a string, e.g., duration = \"120s\""));
32                        }
33                    }
34                    else if nv.path.is_ident("flavor")
35                        || nv.path.is_ident("worker_threads")
36                        || nv.path.is_ident("crate")
37                        || nv.path.is_ident("max_blocking_threads")
38                        || nv.path.is_ident("thread_name")
39                        || nv.path.is_ident("thread_stack_size")
40                        || nv.path.is_ident("global_queue_interval")
41                        || nv.path.is_ident("event_interval") {
42                        continue;
43                    }
44                    else {
45                        return Err(syn::Error::new_spanned(nv.path, "Unknown argument. Did you mean 'duration'?"));
46                    }
47                }
48                Meta::Path(_) => {
49                    continue;
50                }
51                _ => {
52                    return Err(syn::Error::new_spanned(meta, "Unsupported attribute argument. Use key-value pairs, e.g., duration = \"120s\""));
53                }
54            }
55        }
56
57        Ok(macro_args)
58    }
59}
60
61fn parse_duration(s: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
62    if let Some(ms) = s.strip_suffix("ms") {
63        if let Ok(ms) = ms.parse::<u64>() {
64            return Ok(quote! { ::std::time::Duration::from_millis(#ms) });
65        }
66    }
67    if let Some(m) = s.strip_suffix("m").or_else(|| s.strip_suffix("min")).or_else(|| s.strip_suffix("mins")) {
68        if let Ok(mins) = m.parse::<u64>() {
69            let secs = mins * 60;
70            return Ok(quote! { ::std::time::Duration::from_secs(#secs) });
71        }
72    }
73    if let Some(secs) = s.strip_suffix("s").or_else(|| s.strip_suffix("sec")).or_else(|| s.strip_suffix("secs")) {
74        if let Ok(secs) = secs.parse::<u64>() {
75            return Ok(quote! { ::std::time::Duration::from_secs(#secs) });
76        }
77    }
78    Err(syn::Error::new_spanned(s, "Failed to parse duration. Use 'ms' (for milliseconds), 'm', 'min', 'mins' (for minutes), 's', 'sec', 'secs' (for seconds)"))
79}
80
81#[proc_macro_attribute]
82pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
83    let input = parse_macro_input!(item as ItemFn);
84    let args = parse_macro_input!(attr with Punctuated<Meta, Token![,]>::parse_terminated);
85
86    let macro_args = match MacroArgs::parse(args) {
87        Ok(args) => args,
88        Err(e) => return e.to_compile_error().into()
89    };
90
91    let attrs = input.attrs;
92    let vis = input.vis;
93    let sig = input.sig;
94    let body = input.block;
95    let fn_name = &sig.ident;
96    let duration = macro_args.duration;
97
98    if sig.asyncness.is_none() {
99        return syn::Error::new_spanned(sig.fn_token, "#[fracture::test] only supports async functions").to_compile_error().into();
100    }
101
102    let expanded = quote! {
103        #[::core::prelude::v1::test]
104        #(#attrs)*
105        #vis fn #fn_name() {
106            ::fracture::chaos::init_from_env();
107
108            let runtime = ::fracture::runtime::Runtime::new();
109
110            runtime.block_on(async {
111                ::fracture::chaos::trace::clear_trace();
112                ::fracture::chaos::invariants::reset();
113
114                let checker_handle = ::fracture::task::spawn(async {
115                    loop {
116                        if !::fracture::chaos::invariants::check_all() {
117                            break;
118                        }
119
120                        ::fracture::time::sleep(::std::time::Duration::from_millis(100)).await;
121                    }
122                });
123
124                let test_body = async #body;
125                let test_result = ::fracture::time::timeout(#duration, test_body).await;
126
127                checker_handle.abort();
128
129                let violations = ::fracture::chaos::invariants::get_violations();
130                let bugs = ::fracture::chaos::trace::find_bugs();
131                let seed = ::fracture::chaos::get_seed();
132                let trace = ::fracture::chaos::trace::get_trace();
133
134                let has_failure = !violations.is_empty() || !bugs.is_empty() || test_result.is_err();
135
136                if has_failure {
137                    let report = ::fracture::chaos::visualization::generate_report(seed, violations, bugs, trace);
138                    let report_string = report.generate_report_string();
139                    
140                    if test_result.is_err() {
141                        panic!("\n\n{}\n\nFracture test timed out after {:?}\n\n", report_string, #duration);
142                    } else {
143                        panic!("\n\n{}\n\n", report_string);
144                    }
145                }
146            });
147        }
148    };
149
150    TokenStream::from(expanded)
151}
152
153#[proc_macro_attribute]
154pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
155    let input = parse_macro_input!(item as ItemFn);
156
157    let args = parse_macro_input!(attr with Punctuated<Meta, Token![,]>::parse_terminated);
158
159    let macro_args = match MacroArgs::parse(args) {
160        Ok(args) => args,
161        Err(e) => return e.to_compile_error().into(),
162    };
163
164    let attrs = input.attrs;
165    let vis = input.vis;
166    let sig = input.sig;
167    let body = input.block;
168    let fn_name = &sig.ident;
169    let _duration = macro_args.duration;
170
171    let ret = match sig.output {
172        ReturnType::Default => quote! {},
173        ReturnType::Type(_, ty) => quote! { -> #ty }
174    };
175
176    let expanded = quote! {
177        #(#attrs)*
178        #vis fn #fn_name() #ret {
179            #[cfg(feature = "simulation")]
180            {
181                ::fracture::chaos::init_from_env();
182
183                let runtime = ::fracture::runtime::Runtime::new();
184                
185                runtime.block_on(async {
186                    #body
187                })
188            }
189
190            #[cfg(not(feature = "simulation"))]
191            {
192                ::tokio::runtime::Builder::new_multi_thread()
193                    .enable_all()
194                    .build()
195                    .expect("Failed to build async runtime")
196                    .block_on(async {
197                        #body
198                    })
199            }
200        }
201    };
202
203    TokenStream::from(expanded)
204}
205
/// Function-like `select!` macro entry point.
///
/// Forwards the raw input tokens to `select::select` in the sibling `select` module,
/// which produces the expansion.
#[proc_macro]
pub fn select(input: TokenStream) -> TokenStream {
    select::select(input)
}
210
/// Function-like `join!` macro entry point.
///
/// Forwards the raw input tokens to `select::join` and converts its result into a
/// `proc_macro::TokenStream` with `.into()`.
#[proc_macro]
pub fn join(input: TokenStream) -> TokenStream {
    select::join(input).into()
}
215
/// Function-like `try_join!` macro entry point.
///
/// Forwards the raw input tokens to `select::try_join` in the sibling `select` module,
/// which produces the expansion.
#[proc_macro]
pub fn try_join(input: TokenStream) -> TokenStream {
    select::try_join(input)
}
220
/// Function-like `pin!` macro entry point.
///
/// Forwards the raw input tokens to `select::pin` in the sibling `select` module,
/// which produces the expansion.
#[proc_macro]
pub fn pin(input: TokenStream) -> TokenStream {
    select::pin(input)
}
225
226struct TaskLocalInput {
227    attrs: Vec<syn::Attribute>,
228    vis: syn::Visibility,
229    name: syn::Ident,
230    ty: syn::Type,
231    init: syn::Expr
232}
233
234impl syn::parse::Parse for TaskLocalInput {
235    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
236        let attrs = input.call(Attribute::parse_outer)?;
237        let vis = input.parse()?;
238        input.parse::<Token![static]>()?;
239        let name = input.parse()?;
240        input.parse::<Token![:]>()?;
241        let ty = input.parse()?;
242        input.parse::<Token![=]>()?;
243        let init = input.parse()?;
244
245        Ok(TaskLocalInput {
246            attrs,
247            vis,
248            name,
249            ty,
250            init
251        })
252    }
253}
254
255#[proc_macro]
256pub fn task_local(input: TokenStream) -> TokenStream {
257    let input = parse_macro_input!(input as TaskLocalInput);
258
259    let vis = &input.vis;
260    let name = &input.name;
261    let ty = &input.ty;
262    let init = &input.init;
263    let attrs = &input.attrs;
264
265    let expanded = quote! {
266        #(#attrs)*
267        #vis static #name: ::fracture::task::LocalKey<#ty> = {
268            thread_local! {
269                static INNER: ::std::cell::RefCell<Option<#ty>> = ::std::cell::RefCell::new(None);
270            }
271
272            ::fracture::task::LocalKey {
273                inner: &INNER,
274                init: || #init
275            }
276        };
277    };
278
279    TokenStream::from(expanded)
280}