batched_derive 0.2.11

Rust macro utility for batching expensive operations
Documentation
use std::collections::BTreeMap;

use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{
    FnArg, GenericArgument, ItemFn, Meta, Pat, PathArguments, ReturnType, Token, Type,
    parse::Parser, punctuated::Punctuated,
};

use crate::utils::expr_to_u64;

/// Parsed description of the function item the macro was applied to.
///
/// Built by [`Function::parse`]; every field is either a plain string or a
/// token stream extracted from the original `syn::ItemFn`.
#[derive(Debug)]
pub struct Function {
    /// Attributes attached to the function, one token stream per attribute.
    pub macros: Vec<TokenStream>,
    /// Name of the function.
    pub identifier: String,
    /// Visibility qualifier of the function as tokens.
    pub visibility: TokenStream,
    /// The function body block as tokens.
    pub inner: TokenStream,
    /// The complete single argument (pattern and type) as tokens.
    pub batched_arg: TokenStream,
    /// Identifier of the single argument.
    pub batched_arg_name: String,
    /// First generic argument of the argument's `Vec<..>` type.
    pub batched_arg_type: TokenStream,
    /// Parsed return type of the function.
    pub returned: FunctionResult,
}

/// A return type together with its classification.
#[derive(Debug)]
pub struct FunctionResult {
    /// How the type was classified (raw, `Vec`, or `Result`).
    pub result_type: FunctionResultType,
    /// The complete type exactly as written, as tokens.
    pub tokens: TokenStream,
}

/// Classification of a return type as recognised by [`Function::parse`].
#[derive(Debug)]
pub enum FunctionResultType {
    /// Any type that is neither `Vec<..>` nor `Result<..>`; holds the type's tokens.
    Raw(TokenStream),
    /// A `Vec<..>`; holds the tokens of its generic arguments.
    VectorRaw(TokenStream),
    /// A `Result<..>`: the parsed output type, the error type's tokens, and,
    /// when the error type is `SharedError<..>`, the tokens of its generic
    /// arguments.
    Result(Box<FunctionResult>, TokenStream, Option<TokenStream>)
}

/// Extracts the generic arguments of a `SharedError<..>` error type.
///
/// Returns `Some(tokens)` holding the generic arguments when the last path
/// segment of `_type` is `SharedError`, and `None` for every other type
/// (including types that are not paths at all, such as tuples or references).
///
/// # Panics
///
/// Panics when the type is named `SharedError` but carries no angle-bracketed
/// generic arguments, since the inner error type cannot be determined.
fn inner_shared_error(_type: &Type) -> Option<TokenStream> {
    // Anything that is not a path (tuple, reference, ...) cannot be a `SharedError`.
    let Type::Path(type_path) = _type else {
        return None;
    };

    let segment = type_path.path.segments.last()?;
    if segment.ident != "SharedError" {
        return None;
    }

    match &segment.arguments {
        PathArguments::AngleBracketed(generics) => Some(generics.args.to_token_stream()),
        _ => panic!("`SharedError` must specify its inner error type, e.g. `SharedError<E>`"),
    }
}

impl Function {
    pub fn parse(tokens: TokenStream) -> Self {
        let function: ItemFn = syn::parse2(tokens).expect("invalid function");

        let macros = function
            .attrs
            .into_iter()
            .map(|attr| attr.into_token_stream())
            .collect();

        let visibility = function.vis.into_token_stream();
        let identifier = function.sig.ident.to_string();
        let args = function.sig.inputs;
        let inner = function.block.to_token_stream();

        fn parsed_returned(_type: &Type) -> FunctionResult {
            let tokens = _type.clone().into_token_stream();
            let result_type = match _type {
                Type::Path(type_path) => {
                    let path = type_path.path.segments.first().unwrap();

                    if path.ident == "Vec" {
                        let inner = match &path.arguments {
                            PathArguments::AngleBracketed(b) => b,
                            _ => unimplemented!(),
                        };
                        let inner = inner.args.to_token_stream();
                        FunctionResultType::VectorRaw(inner)
                    } else if path.ident == "Result" {
                        let inner = match &path.arguments {
                            PathArguments::AngleBracketed(b) => b,
                            _ => unimplemented!(),
                        };

                        let output = inner.args.get(0).unwrap();
                        let error = inner.args.get(1).unwrap();
                        let error = match error {
                            GenericArgument::Type(error) => error,
                            _ => unimplemented!(),
                        };

                        let output = match output {
                            GenericArgument::Type(_type) => parsed_returned(_type),
                            _ => unimplemented!(),
                        };
                        let inner_shared_error = inner_shared_error(error);
                        let error = error.into_token_stream();

                        FunctionResultType::Result(Box::new(output), error, inner_shared_error)
                    } else {
                        FunctionResultType::Raw(_type.into_token_stream())
                    }
                }
                _ => FunctionResultType::Raw(_type.into_token_stream()),
            };

            FunctionResult {
                tokens,
                result_type,
            }
        }
        let returned = match function.sig.output {
            ReturnType::Default => FunctionResult {
                tokens: syn::parse_str("()").unwrap(),
                result_type: FunctionResultType::Raw(syn::parse_str("()").unwrap()),
            },
            ReturnType::Type(_, _type) => parsed_returned(&_type),
        };

        let mut batched_arg: Option<TokenStream> = None;
        let mut batched_arg_name: Option<String> = None;
        let mut batched_arg_type: Option<TokenStream> = None;

        for arg in args {
            if let FnArg::Receiver(_) = arg {
                panic!("self reference functions are not supported")
            } else if let FnArg::Typed(arg) = arg {
                if batched_arg_type.is_some() {
                    panic!("function may only contain a single argument")
                }

                batched_arg = Some(arg.to_token_stream());
                batched_arg_name = match &*arg.pat {
                    Pat::Ident(pat_ident) => Some(pat_ident.ident.to_string()),
                    _ => panic!("unsupport argument name"),
                };

                if let syn::Type::Path(type_path) = &*arg.ty {
                    let segment = type_path.path.segments.last().unwrap();
                    if segment.ident == "Vec" {
                        let vec_type = match &segment.arguments {
                            PathArguments::AngleBracketed(a) => a,
                            _ => unreachable!(),
                        };
                        let vec_type = vec_type.args.first().unwrap();
                        batched_arg_type = Some(vec_type.to_token_stream());
                        continue;
                    }
                }

                panic!("function argument must be Vec<T>")
            }
        }

        if batched_arg_type.is_none() {
            panic!("function must contain a single argument")
        }

        let batched_arg = batched_arg.unwrap();
        let batched_arg_name = batched_arg_name.unwrap();
        let batched_arg_type = batched_arg_type.unwrap();

        Self {
            macros,
            identifier,
            visibility,
            batched_arg,
            batched_arg_name,
            batched_arg_type,
            returned,
            inner,
        }
    }
}

/// Options parsed from the macro's attribute arguments by [`Attributes::parse`].
#[derive(Debug)]
pub struct Attributes {
    /// Value of the `limit` attribute, if given as an integer.
    pub limit: Option<usize>,
    /// Value of the `concurrent` attribute, if given as an integer.
    pub concurrent_limit: Option<usize>,
    /// `true` when the `asynchronous` attribute is present.
    pub asynchronous: bool,
    /// `true` when the `passthrough` attribute is present.
    pub passthrough: bool,
    /// Value of the required `window` attribute
    /// (presumably a duration in milliseconds — TODO confirm against the code generator).
    pub default_window: u64,
    /// Values of the `window<N>` attributes, keyed by the call size `N`.
    pub windows: BTreeMap<u64, u64>,
}

impl Attributes {
    pub fn parse(tokens: TokenStream) -> Self {
        let mut limit: Option<usize> = None;
        let mut concurrent_limit: Option<usize> = None;
        let mut asynchronous = false;
        let mut passthrough = false;
        let mut default_window: Option<u64> = None;
        let mut windows = BTreeMap::new();

        static WINDOW_ATTR: &str = "window";
        static LIMIT_ATTR: &str = "limit";
        static CONCURRENT_LIMIT_ATTR: &str = "concurrent";
        static ASYNCHRONOUS_ATTR: &str = "asynchronous";
        static PASSTHROUGH_ATTR: &str = "passthrough";

        let parser = Punctuated::<Meta, Token![,]>::parse_separated_nonempty;
        let attributes = parser.parse(tokens.into()).unwrap();
        let attributes = attributes.into_iter().collect::<Vec<Meta>>();

        for attr in &attributes {
            let path = attr.path();
            if path.is_ident(LIMIT_ATTR) {
                let value = match attr {
                    Meta::NameValue(attr) => &attr.value,
                    _ => unimplemented!(),
                };

                limit = expr_to_u64(value).map(|u| u as usize);
            } else if path.is_ident(CONCURRENT_LIMIT_ATTR) {
                let value = match attr {
                    Meta::NameValue(attr) => &attr.value,
                    _ => unimplemented!(),
                };

                concurrent_limit = expr_to_u64(value).map(|u| u as usize);
            } else if path.is_ident(ASYNCHRONOUS_ATTR) {
                asynchronous = true;
            } else if path.is_ident(PASSTHROUGH_ATTR) {
                passthrough = true;
            } else if path.is_ident(WINDOW_ATTR) {
                let value = match attr {
                    Meta::NameValue(attr) => &attr.value,
                    _ => unimplemented!(),
                };

                let window_duration_ms = expr_to_u64(value);
                default_window = window_duration_ms;
            } else if let Some(ident) = path.get_ident().map(|i| i.to_string())
                && ident.starts_with(WINDOW_ATTR)
            {
                let value = match attr {
                    Meta::NameValue(attr) => &attr.value,
                    _ => unimplemented!(),
                };

                let call_size = ident.replace(WINDOW_ATTR, "");
                let call_size = call_size.parse::<u64>().unwrap();
                let call_window = expr_to_u64(value).expect("expected u64");

                let unsorted = windows
                    .iter()
                    .find(|(_call_size, _)| **_call_size > call_size);
                if unsorted.is_some() {
                    panic!("dynamic window call size must be sorted")
                }

                windows.insert(call_size, call_window);
            }
        }

        let default_window = default_window
            .expect("expected required attribute: window");

        Self {
            limit,
            concurrent_limit,
            asynchronous,
            passthrough,
            default_window,
            windows,
        }
    }
}