//! fracture-macros 0.1.0
//!
//! Proc macros for the Fracture chaos testing framework.
mod select;

use proc_macro::TokenStream;
use quote::quote;
use syn::{Attribute, Expr, ExprLit, ItemFn, Lit, Meta, ReturnType, Token, parse_macro_input, punctuated::Punctuated, token::Comma};

/// Options parsed from the arguments of `#[fracture::test(...)]` / `#[fracture::main(...)]`.
struct MacroArgs {
    /// Tokens of an expression evaluating to a `::std::time::Duration`.
    /// `#[fracture::test]` interpolates it as the timeout for the test body.
    duration: proc_macro2::TokenStream,
}

impl Default for MacroArgs {
    /// Sixty seconds, used when no `duration = "..."` argument is supplied.
    fn default() -> Self {
        let sixty_seconds = quote!(::std::time::Duration::from_secs(60));
        MacroArgs {
            duration: sixty_seconds,
        }
    }
}

impl MacroArgs {
    fn parse(args: Punctuated<Meta, Comma>) -> Result<Self, syn::Error> {
        let mut macro_args = Self::default();

        for meta in args {
            match meta {
                Meta::NameValue(nv) => {
                    if nv.path.is_ident("duration") {
                        if let Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. })= &nv.value {
                            macro_args.duration = parse_duration(&lit_str.value())?;
                        }
                        else {
                            return Err(syn::Error::new_spanned(&nv.value, "Expected duration as a string, e.g., duration = \"120s\""));
                        }
                    }
                    else if nv.path.is_ident("flavor")
                        || nv.path.is_ident("worker_threads")
                        || nv.path.is_ident("crate")
                        || nv.path.is_ident("max_blocking_threads")
                        || nv.path.is_ident("thread_name")
                        || nv.path.is_ident("thread_stack_size")
                        || nv.path.is_ident("global_queue_interval")
                        || nv.path.is_ident("event_interval") {
                        continue;
                    }
                    else {
                        return Err(syn::Error::new_spanned(nv.path, "Unknown argument. Did you mean 'duration'?"));
                    }
                }
                Meta::Path(_) => {
                    continue;
                }
                _ => {
                    return Err(syn::Error::new_spanned(meta, "Unsupported attribute argument. Use key-value pairs, e.g., duration = \"120s\""));
                }
            }
        }

        Ok(macro_args)
    }
}

fn parse_duration(s: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
    if let Some(ms) = s.strip_suffix("ms") {
        if let Ok(ms) = ms.parse::<u64>() {
            return Ok(quote! { ::std::time::Duration::from_millis(#ms) });
        }
    }
    if let Some(m) = s.strip_suffix("m").or_else(|| s.strip_suffix("min")).or_else(|| s.strip_suffix("mins")) {
        if let Ok(mins) = m.parse::<u64>() {
            let secs = mins * 60;
            return Ok(quote! { ::std::time::Duration::from_secs(#secs) });
        }
    }
    if let Some(secs) = s.strip_suffix("s").or_else(|| s.strip_suffix("sec")).or_else(|| s.strip_suffix("secs")) {
        if let Ok(secs) = secs.parse::<u64>() {
            return Ok(quote! { ::std::time::Duration::from_secs(#secs) });
        }
    }
    Err(syn::Error::new_spanned(s, "Failed to parse duration. Use 'ms' (for milliseconds), 'm', 'min', 'mins' (for minutes), 's', 'sec', 'secs' (for seconds)"))
}

#[proc_macro_attribute]
pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as ItemFn);
    let args = parse_macro_input!(attr with Punctuated<Meta, Token![,]>::parse_terminated);

    let macro_args = match MacroArgs::parse(args) {
        Ok(args) => args,
        Err(e) => return e.to_compile_error().into()
    };

    let attrs = input.attrs;
    let vis = input.vis;
    let sig = input.sig;
    let body = input.block;
    let fn_name = &sig.ident;
    let duration = macro_args.duration;

    if sig.asyncness.is_none() {
        return syn::Error::new_spanned(sig.fn_token, "#[fracture::test] only supports async functions").to_compile_error().into();
    }

    let expanded = quote! {
        #[::core::prelude::v1::test]
        #(#attrs)*
        #vis fn #fn_name() {
            ::fracture::chaos::init_from_env();

            let runtime = ::fracture::runtime::Runtime::new();

            runtime.block_on(async {
                ::fracture::chaos::trace::clear_trace();
                ::fracture::chaos::invariants::reset();

                let checker_handle = ::fracture::task::spawn(async {
                    loop {
                        if !::fracture::chaos::invariants::check_all() {
                            break;
                        }

                        ::fracture::time::sleep(::std::time::Duration::from_millis(100)).await;
                    }
                });

                let test_body = async #body;
                let test_result = ::fracture::time::timeout(#duration, test_body).await;

                checker_handle.abort();

                let violations = ::fracture::chaos::invariants::get_violations();
                let bugs = ::fracture::chaos::trace::find_bugs();
                let seed = ::fracture::chaos::get_seed();
                let trace = ::fracture::chaos::trace::get_trace();

                let has_failure = !violations.is_empty() || !bugs.is_empty() || test_result.is_err();

                if has_failure {
                    let report = ::fracture::chaos::visualization::generate_report(seed, violations, bugs, trace);
                    let report_string = report.generate_report_string();
                    
                    if test_result.is_err() {
                        panic!("\n\n{}\n\nFracture test timed out after {:?}\n\n", report_string, #duration);
                    } else {
                        panic!("\n\n{}\n\n", report_string);
                    }
                }
            });
        }
    };

    TokenStream::from(expanded)
}

#[proc_macro_attribute]
pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as ItemFn);

    let args = parse_macro_input!(attr with Punctuated<Meta, Token![,]>::parse_terminated);

    let macro_args = match MacroArgs::parse(args) {
        Ok(args) => args,
        Err(e) => return e.to_compile_error().into(),
    };

    let attrs = input.attrs;
    let vis = input.vis;
    let sig = input.sig;
    let body = input.block;
    let fn_name = &sig.ident;
    let _duration = macro_args.duration;

    let ret = match sig.output {
        ReturnType::Default => quote! {},
        ReturnType::Type(_, ty) => quote! { -> #ty }
    };

    let expanded = quote! {
        #(#attrs)*
        #vis fn #fn_name() #ret {
            #[cfg(feature = "simulation")]
            {
                ::fracture::chaos::init_from_env();

                let runtime = ::fracture::runtime::Runtime::new();
                
                runtime.block_on(async {
                    #body
                })
            }

            #[cfg(not(feature = "simulation"))]
            {
                ::tokio::runtime::Builder::new_multi_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to build async runtime")
                    .block_on(async {
                        #body
                    })
            }
        }
    };

    TokenStream::from(expanded)
}

#[proc_macro]
pub fn select(input: TokenStream) -> TokenStream {
    select::select(input)
}

#[proc_macro]
pub fn join(input: TokenStream) -> TokenStream {
    select::join(input).into()
}

#[proc_macro]
pub fn try_join(input: TokenStream) -> TokenStream {
    select::try_join(input)
}

#[proc_macro]
pub fn pin(input: TokenStream) -> TokenStream {
    select::pin(input)
}

/// Parsed input of `task_local!`:
/// `#[attr]* vis static NAME: Type = init_expr`, optionally followed by a `;`.
struct TaskLocalInput {
    /// Outer attributes (including doc comments), re-emitted on the generated static.
    attrs: Vec<syn::Attribute>,
    /// Visibility of the generated static.
    vis: syn::Visibility,
    /// Name of the generated static.
    name: syn::Ident,
    /// Type of the task-local value.
    ty: syn::Type,
    /// Initializer expression, evaluated by the generated `init` closure.
    init: syn::Expr,
}

impl syn::parse::Parse for TaskLocalInput {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        let vis = input.parse()?;
        input.parse::<Token![static]>()?;
        let name = input.parse()?;
        input.parse::<Token![:]>()?;
        let ty = input.parse()?;
        input.parse::<Token![=]>()?;
        let init = input.parse()?;
        // Accept an optional trailing `;` (as in `thread_local!` declarations). Without this,
        // the leftover token makes `parse_macro_input!` reject the input.
        input.parse::<Option<Token![;]>>()?;

        Ok(TaskLocalInput {
            attrs,
            vis,
            name,
            ty,
            init,
        })
    }
}

#[proc_macro]
pub fn task_local(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as TaskLocalInput);

    let vis = &input.vis;
    let name = &input.name;
    let ty = &input.ty;
    let init = &input.init;
    let attrs = &input.attrs;

    let expanded = quote! {
        #(#attrs)*
        #vis static #name: ::fracture::task::LocalKey<#ty> = {
            thread_local! {
                static INNER: ::std::cell::RefCell<Option<#ty>> = ::std::cell::RefCell::new(None);
            }

            ::fracture::task::LocalKey {
                inner: &INNER,
                init: || #init
            }
        };
    };

    TokenStream::from(expanded)
}