batched_derive 0.1.3

Rust macro utility for batching expensive operations
Documentation
use std::usize;

use proc_macro2::TokenStream;
use quote::{format_ident, quote};

use crate::types::{Attributes, Function};

/// Permit count used for the generated semaphore when no `concurrent_limit` is configured.
///
/// Computed as `usize::MAX >> 3` rather than spelled out as a 64-bit literal so the constant
/// stays in range on hosts where `usize` is narrower than 64 bits. On 64-bit hosts the value is
/// `2305843009213693951`.
// NOTE(review): this is intended to mirror tokio's `Semaphore::MAX_PERMITS` — confirm against
// the tokio version in use.
static MAX_SEMAPHORE_PERMITS: usize = usize::MAX >> 3;

/// Expands a user-written batch handler into the batching machinery that surrounds it.
///
/// With `<name>` being `call_function.identifier` with every `_batched` occurrence removed,
/// the returned tokens define:
///
/// - `BATCHED_<NAME>_PRODUCER_CHANNEL`: a `static` `OnceCell` holding the sender side of the
///   request channel;
/// - `__spawn_background_<name>_batched`: spawns the collector task and returns that sender;
/// - `__<name>_batched`: the original function body, receiving the whole batch;
/// - `<name>`: single-item entry point (original visibility), forwards to `<name>_multiple`;
/// - `<name>_multiple`: entry point (original visibility) that submits a `Vec` of items and
///   awaits the result of the batch they ended up in.
///
/// Collector behaviour, as generated:
///
/// - a batch is flushed once at least `options.limit` items are buffered or when an interval of
///   `options.window` milliseconds ticks, whichever happens first; a whole submission is appended
///   before the size check, so a batch can hold more than `options.limit` items;
/// - each flushed batch runs in its own spawned task, gated by a semaphore with
///   `options.concurrent_limit` permits (`MAX_SEMAPHORE_PERMITS` when unset);
/// - when `options.wrap_in_arc` is set the batch result is wrapped in `Arc` before being cloned
///   to every waiting caller, and both entry points return `Arc<...>` of the original return type.
///
/// # Panics
///
/// Panics if the derived name or the argument name cannot be turned into tokens / identifiers.
pub fn build_code(call_function: Function, options: Attributes) -> TokenStream {
    // NOTE: `replace` removes every `_batched` occurrence in the identifier, not only a suffix.
    let name = call_function.identifier.replace("_batched", "");
    let visibility = call_function.visibility;
    let arg = call_function.batched_arg;
    let arg_name = call_function.batched_arg_name;
    let arg_type = call_function.batched_arg_type;
    let returned = call_function.return_value;
    let inner_body = call_function.inner;

    let capacity = options.limit;
    let window = options.window;
    let concurrent_limit = options.concurrent_limit.unwrap_or(MAX_SEMAPHORE_PERMITS);

    // Identifiers of the generated items.
    let batched_producer_channel =
        format_ident!("BATCHED_{}_PRODUCER_CHANNEL", name.to_uppercase());
    let __spawn_background_batch = format_ident!("__spawn_background_{name}_batched");
    let batched = format_ident!("__{name}_batched");
    let fnname_multiple = format_ident!("{name}_multiple");

    let name = syn::parse_str::<TokenStream>(&name)
        .expect("derived function name must be valid Rust tokens");
    let arg_name = syn::parse_str::<TokenStream>(&arg_name)
        .expect("batched argument name must be valid Rust tokens");

    // Build the (optionally Arc-wrapped) result type and the wrapping statement directly as
    // tokens instead of formatting a string and parsing it back.
    let returned_arc = if options.wrap_in_arc {
        quote! { ::std::sync::Arc<#returned> }
    } else {
        quote! { #returned }
    };
    let wrap_in_arc = options
        .wrap_in_arc
        .then(|| quote! { let result = ::std::sync::Arc::new(result); });

    // Sender half of the request channel. Each message carries the items of one submission plus
    // the sender used to hand the batch result back to that caller.
    let request_sender = quote! {
        ::tokio::sync::mpsc::Sender<(Vec<#arg_type>, ::tokio::sync::mpsc::Sender<#returned_arc>)>
    };

    // All `tokio` / `std` paths below are absolute (`::tokio`, `::std`) so that items named
    // `tokio` or `std` at the expansion site cannot shadow them.
    quote! {
        static #batched_producer_channel: ::tokio::sync::OnceCell<#request_sender> =
            ::tokio::sync::OnceCell::const_new();

        async fn #__spawn_background_batch() -> #request_sender {
            let capacity = #capacity;
            let window = ::tokio::time::Duration::from_millis(#window);

            let (sender, mut receiver) = ::tokio::sync::mpsc::channel(capacity);
            ::tokio::task::spawn(async move {
                let mut buffer = Vec::with_capacity(capacity);
                let mut channels: Vec<::tokio::sync::mpsc::Sender<#returned_arc>> = Vec::new();
                let semaphore = ::std::sync::Arc::new(::tokio::sync::Semaphore::new(#concurrent_limit));

                loop {
                    // A fresh interval per batch; its first tick completes immediately, so it is
                    // consumed here and the next tick marks the end of the window.
                    let mut timer = ::tokio::time::interval(window);
                    timer.tick().await;

                    loop {
                        ::tokio::select! {
                            event = receiver.recv() => {
                                // `None` means every sender is gone: stop the collector.
                                let Some((mut calls, channel)) = event else {
                                    return;
                                };

                                buffer.append(&mut calls);
                                channels.push(channel);
                                if buffer.len() >= capacity {
                                    break;
                                }
                            }

                            _ = timer.tick() => {
                                break;
                            }
                        }
                    }

                    // Move the collected batch out, leaving empty vectors behind.
                    let calls = ::std::mem::take(&mut buffer);
                    let return_channels = ::std::mem::take(&mut channels);
                    if calls.is_empty() {
                        continue;
                    }

                    // Waiting for a permit here pauses collection until a running batch ends.
                    let permit = ::std::sync::Arc::clone(&semaphore)
                        .acquire_owned()
                        .await
                        .expect("semaphore is never closed");
                    ::tokio::task::spawn(async move {
                        let _permit = permit;

                        let result = #batched(calls).await;
                        #wrap_in_arc
                        for channel in return_channels {
                            // A caller that stopped waiting is not an error.
                            let _ = channel.try_send(result.clone());
                        }
                    });

                    buffer.reserve(capacity);
                }
            });

            sender
        }

        async fn #batched(#arg) -> #returned {
            #inner_body
        }

        #visibility async fn #name(#arg_name: #arg_type) -> #returned_arc {
            #fnname_multiple(vec![#arg_name]).await
        }

        #visibility async fn #fnname_multiple(#arg_name: Vec<#arg_type>) -> #returned_arc {
            // The collector task is spawned lazily on first use.
            let channel = #batched_producer_channel
                .get_or_init(#__spawn_background_batch)
                .await;

            let (response_channel_sender, mut response_channel_recv) = ::tokio::sync::mpsc::channel(1);
            channel.send((#arg_name, response_channel_sender)).await
                .expect("batched background thread is gone");

            response_channel_recv.recv().await.expect("task panicked")
        }
    }
}