1mod select;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{Attribute, Expr, ExprLit, ItemFn, Lit, Meta, ReturnType, Token, parse_macro_input, punctuated::Punctuated, token::Comma};
6
7struct MacroArgs {
8 duration: proc_macro2::TokenStream
9}
10
11impl Default for MacroArgs {
12 fn default() -> Self {
13 Self {
14 duration: quote! { ::std::time::Duration::from_secs(60) }
15 }
16 }
17}
18
19impl MacroArgs {
20 fn parse(args: Punctuated<Meta, Comma>) -> Result<Self, syn::Error> {
21 let mut macro_args = Self::default();
22
23 for meta in args {
24 match meta {
25 Meta::NameValue(nv) => {
26 if nv.path.is_ident("duration") {
27 if let Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. })= &nv.value {
28 macro_args.duration = parse_duration(&lit_str.value())?;
29 }
30 else {
31 return Err(syn::Error::new_spanned(&nv.value, "Expected duration as a string, e.g., duration = \"120s\""));
32 }
33 }
34 else if nv.path.is_ident("flavor")
35 || nv.path.is_ident("worker_threads")
36 || nv.path.is_ident("crate")
37 || nv.path.is_ident("max_blocking_threads")
38 || nv.path.is_ident("thread_name")
39 || nv.path.is_ident("thread_stack_size")
40 || nv.path.is_ident("global_queue_interval")
41 || nv.path.is_ident("event_interval") {
42 continue;
43 }
44 else {
45 return Err(syn::Error::new_spanned(nv.path, "Unknown argument. Did you mean 'duration'?"));
46 }
47 }
48 Meta::Path(_) => {
49 continue;
50 }
51 _ => {
52 return Err(syn::Error::new_spanned(meta, "Unsupported attribute argument. Use key-value pairs, e.g., duration = \"120s\""));
53 }
54 }
55 }
56
57 Ok(macro_args)
58 }
59}
60
61fn parse_duration(s: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
62 if let Some(ms) = s.strip_suffix("ms") {
63 if let Ok(ms) = ms.parse::<u64>() {
64 return Ok(quote! { ::std::time::Duration::from_millis(#ms) });
65 }
66 }
67 if let Some(m) = s.strip_suffix("m").or_else(|| s.strip_suffix("min")).or_else(|| s.strip_suffix("mins")) {
68 if let Ok(mins) = m.parse::<u64>() {
69 let secs = mins * 60;
70 return Ok(quote! { ::std::time::Duration::from_secs(#secs) });
71 }
72 }
73 if let Some(secs) = s.strip_suffix("s").or_else(|| s.strip_suffix("sec")).or_else(|| s.strip_suffix("secs")) {
74 if let Ok(secs) = secs.parse::<u64>() {
75 return Ok(quote! { ::std::time::Duration::from_secs(#secs) });
76 }
77 }
78 Err(syn::Error::new_spanned(s, "Failed to parse duration. Use 'ms' (for milliseconds), 'm', 'min', 'mins' (for minutes), 's', 'sec', 'secs' (for seconds)"))
79}
80
81#[proc_macro_attribute]
82pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
83 let input = parse_macro_input!(item as ItemFn);
84 let args = parse_macro_input!(attr with Punctuated<Meta, Token![,]>::parse_terminated);
85
86 let macro_args = match MacroArgs::parse(args) {
87 Ok(args) => args,
88 Err(e) => return e.to_compile_error().into()
89 };
90
91 let attrs = input.attrs;
92 let vis = input.vis;
93 let sig = input.sig;
94 let body = input.block;
95 let fn_name = &sig.ident;
96 let duration = macro_args.duration;
97
98 if sig.asyncness.is_none() {
99 return syn::Error::new_spanned(sig.fn_token, "#[fracture::test] only supports async functions").to_compile_error().into();
100 }
101
102 let expanded = quote! {
103 #[::core::prelude::v1::test]
104 #(#attrs)*
105 #vis fn #fn_name() {
106 ::fracture::chaos::init_from_env();
107
108 let runtime = ::fracture::runtime::Runtime::new();
109
110 runtime.block_on(async {
111 ::fracture::chaos::trace::clear_trace();
112 ::fracture::chaos::invariants::reset();
113
114 let checker_handle = ::fracture::task::spawn(async {
115 loop {
116 if !::fracture::chaos::invariants::check_all() {
117 break;
118 }
119
120 ::fracture::time::sleep(::std::time::Duration::from_millis(100)).await;
121 }
122 });
123
124 let test_body = async #body;
125 let test_result = ::fracture::time::timeout(#duration, test_body).await;
126
127 checker_handle.abort();
128
129 let violations = ::fracture::chaos::invariants::get_violations();
130 let bugs = ::fracture::chaos::trace::find_bugs();
131 let seed = ::fracture::chaos::get_seed();
132 let trace = ::fracture::chaos::trace::get_trace();
133
134 let has_failure = !violations.is_empty() || !bugs.is_empty() || test_result.is_err();
135
136 if has_failure {
137 let report = ::fracture::chaos::visualization::generate_report(seed, violations, bugs, trace);
138 let report_string = report.generate_report_string();
139
140 if test_result.is_err() {
141 panic!("\n\n{}\n\nFracture test timed out after {:?}\n\n", report_string, #duration);
142 } else {
143 panic!("\n\n{}\n\n", report_string);
144 }
145 }
146 });
147 }
148 };
149
150 TokenStream::from(expanded)
151}
152
153#[proc_macro_attribute]
154pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
155 let input = parse_macro_input!(item as ItemFn);
156
157 let args = parse_macro_input!(attr with Punctuated<Meta, Token![,]>::parse_terminated);
158
159 let macro_args = match MacroArgs::parse(args) {
160 Ok(args) => args,
161 Err(e) => return e.to_compile_error().into(),
162 };
163
164 let attrs = input.attrs;
165 let vis = input.vis;
166 let sig = input.sig;
167 let body = input.block;
168 let fn_name = &sig.ident;
169 let _duration = macro_args.duration;
170
171 let ret = match sig.output {
172 ReturnType::Default => quote! {},
173 ReturnType::Type(_, ty) => quote! { -> #ty }
174 };
175
176 let expanded = quote! {
177 #(#attrs)*
178 #vis fn #fn_name() #ret {
179 #[cfg(feature = "simulation")]
180 {
181 ::fracture::chaos::init_from_env();
182
183 let runtime = ::fracture::runtime::Runtime::new();
184
185 runtime.block_on(async {
186 #body
187 })
188 }
189
190 #[cfg(not(feature = "simulation"))]
191 {
192 ::tokio::runtime::Builder::new_multi_thread()
193 .enable_all()
194 .build()
195 .expect("Failed to build async runtime")
196 .block_on(async {
197 #body
198 })
199 }
200 }
201 };
202
203 TokenStream::from(expanded)
204}
205
206#[proc_macro]
207pub fn select(input: TokenStream) -> TokenStream {
208 select::select(input)
209}
210
211#[proc_macro]
212pub fn join(input: TokenStream) -> TokenStream {
213 select::join(input).into()
214}
215
216#[proc_macro]
217pub fn try_join(input: TokenStream) -> TokenStream {
218 select::try_join(input)
219}
220
221#[proc_macro]
222pub fn pin(input: TokenStream) -> TokenStream {
223 select::pin(input)
224}
225
226struct TaskLocalInput {
227 attrs: Vec<syn::Attribute>,
228 vis: syn::Visibility,
229 name: syn::Ident,
230 ty: syn::Type,
231 init: syn::Expr
232}
233
234impl syn::parse::Parse for TaskLocalInput {
235 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
236 let attrs = input.call(Attribute::parse_outer)?;
237 let vis = input.parse()?;
238 input.parse::<Token![static]>()?;
239 let name = input.parse()?;
240 input.parse::<Token![:]>()?;
241 let ty = input.parse()?;
242 input.parse::<Token![=]>()?;
243 let init = input.parse()?;
244
245 Ok(TaskLocalInput {
246 attrs,
247 vis,
248 name,
249 ty,
250 init
251 })
252 }
253}
254
255#[proc_macro]
256pub fn task_local(input: TokenStream) -> TokenStream {
257 let input = parse_macro_input!(input as TaskLocalInput);
258
259 let vis = &input.vis;
260 let name = &input.name;
261 let ty = &input.ty;
262 let init = &input.init;
263 let attrs = &input.attrs;
264
265 let expanded = quote! {
266 #(#attrs)*
267 #vis static #name: ::fracture::task::LocalKey<#ty> = {
268 thread_local! {
269 static INNER: ::std::cell::RefCell<Option<#ty>> = ::std::cell::RefCell::new(None);
270 }
271
272 ::fracture::task::LocalKey {
273 inner: &INNER,
274 init: || #init
275 }
276 };
277 };
278
279 TokenStream::from(expanded)
280}