1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};

/// Attribute macro that wraps a free function with run-time dispatch between
/// its natively compiled body and a dynamically loaded replacement.
///
/// For an annotated function `foo` the expansion contains:
///
/// * `__pogo_native_foo` — the original body under the original parameter
///   list and return type, declared `pub(crate)`;
/// * `__pogo_ctx_foo` — a `pogo::ContextCell` static;
/// * `__pogo_info_foo` — a `pogo::PogoFuncDefinition` static carrying the
///   function name and the stringified source of the annotated item;
/// * `foo` — same visibility and signature as the original, forwarding to
///   `foo_with_group::<pogo::Global>`;
/// * `foo_with_group<Grp: pogo::PogoGroup>` — selects, per call, between the
///   native function and a symbol looked up in the group's loaded library.
///
/// Only the visibility, name, parameters, return type and body of the input
/// are carried over; the attribute arguments (`_attr`) are ignored.
///
/// # Panics
///
/// Panics during expansion (surfacing as a compile error at the use site) if
/// the function has a `self` receiver or if any parameter is bound by
/// something other than a plain identifier pattern.
#[proc_macro_attribute]
pub fn pogo(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as ItemFn);

    // Capture the full source text before `input` is taken apart below.
    let input_src_string = quote!(#input).to_string();

    let function_name = input.sig.ident;
    let function_inputs = input.sig.inputs;
    let return_type = input.sig.output;
    let function_body = input.block;
    let native_func_name = quote::format_ident!("__pogo_native_{}", function_name);

    let native_function = quote! {
        pub(crate) fn #native_func_name(#function_inputs) #return_type {
            #function_body
        }
    };

    // Parameter types (for the `extern fn` pointer signature) and parameter
    // names (for forwarding calls), collected in declaration order.
    let mut type_args: syn::punctuated::Punctuated<Box<syn::Type>, syn::token::Comma> =
        syn::punctuated::Punctuated::new();
    let mut arg_names: syn::punctuated::Punctuated<syn::Ident, syn::token::Comma> =
        syn::punctuated::Punctuated::new();

    for arg in function_inputs.iter() {
        match arg {
            syn::FnArg::Receiver(_) => panic!("Not supported"),
            syn::FnArg::Typed(pat_type) => {
                // `push` inserts the separating comma itself. `push_value`
                // asserts that the list is empty or already ends in
                // punctuation, so it panics on the second parameter.
                match pat_type.pat.as_ref() {
                    syn::Pat::Ident(ident) => arg_names.push(ident.ident.clone()),
                    _ => panic!("Not supported"),
                }
                type_args.push(pat_type.ty.clone());
            }
        }
    }

    let vis = input.vis;
    let group_func_name = quote::format_ident!("{}_with_group", function_name);
    let ctx_name = quote::format_ident!("__pogo_ctx_{}", function_name);
    let info_name = quote::format_ident!("__pogo_info_{}", function_name);

    let str_func_name = function_name.to_string();

    TokenStream::from(quote! {
        #native_function

        #[allow(non_upper_case_globals)]
        static #ctx_name: pogo::ContextCell = pogo::ContextCell::new();

        #[allow(non_upper_case_globals)]
        static #info_name: pogo::PogoFuncDefinition = pogo::PogoFuncDefinition {
            edition: pogo::Edition::Rust2018,
            name: #str_func_name,
            src: #input_src_string,
        };

        #vis fn #function_name(#function_inputs) #return_type {
            #group_func_name::<pogo::Global>(#arg_names)
        }

        #vis fn #group_func_name<Grp: pogo::PogoGroup>(#function_inputs) #return_type {
            match #ctx_name.get() {
                Some(ctx) if Grp::USE_PGO => {
                    match ctx.groups.get(Grp::NAME) {
                        Some(group) => {
                            match &group.pgo_state {
                                pogo::PgoState::Uninitialized | pogo::PgoState::CompilationFailed => {
                                    #native_func_name(#arg_names)
                                }
                                pogo::PgoState::GatheringData(lib) => {
                                    // Count this call; once the group's threshold is reached,
                                    // ask for an optimization pass.
                                    if group.pgo_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) >= Grp::PGO_EXEC_COUNT
                                    {
                                        pogo::submit_optimization_request(ctx, Grp::NAME);
                                    }

                                    // SAFETY (unchecked here): relies on the loaded library exporting
                                    // a symbol under `ctx.info.name` with exactly this signature.
                                    unsafe {
                                        let func: libloading::Symbol<unsafe extern fn(#type_args) #return_type> = lib.get(ctx.info.name.as_bytes()).expect("Run-time compiled shared object didn't contain expected function");
                                        func(#arg_names)
                                    }
                                }
                                // SAFETY (unchecked here): same symbol/signature assumption as above.
                                pogo::PgoState::Compiling(lib) | pogo::PgoState::Optimized(lib) => unsafe {
                                    let func: libloading::Symbol<unsafe extern fn(#type_args) #return_type> = lib.get(ctx.info.name.as_bytes()).expect("Run-time compiled shared object didn't contain expected function");
                                    func(#arg_names)
                                },
                            }
                        }
                        None => {
                            ctx.groups.upsert(
                                Grp::NAME,
                                || {
                                    pogo::GroupState {
                                        pgo_state: pogo::PgoState::Uninitialized,
                                        pgo_count: std::sync::atomic::AtomicUsize::new(0),
                                    }
                                },
                                |_| {
                                    // The value already existed by the time we got to this branch
                                    // so don't touch it, someone should have already initialized it
                                },
                            );
                            // Execute the unoptimized non-tracking version for now
                            #native_func_name(#arg_names)
                        }
                    }
                },
                // No context registered, or PGO disabled for this group: plain native call.
                _ => #native_func_name(#arg_names),
            }
        }
    })
}