use std::iter::once;
use proc_macro2::{TokenStream, TokenTree};
use quote::{quote, format_ident};
use syn::{
parse::{Parse, ParseStream},
parse2,
punctuated::Punctuated,
Expr, Ident, Token, Result,
};
/// One `<freq> -> <func>` entry written inside an executor's braces.
struct TaskEntry {
/// Expression written before the `->`; `task` emits it as the second
/// argument of `TaskNode::new`.
freq: Expr,
/// Expression written after the `->`; `task` emits it wrapped in
/// `Box::pin(..)` as the first argument of `TaskNode::new`.
func: Expr,
}
impl Parse for TaskEntry {
    /// Parses a single `<freq> -> <func>` entry.
    ///
    /// # Errors
    ///
    /// Fails when nothing precedes the `->`, when the tokens before it do not
    /// form one expression, when the `->` is missing, or when what follows it
    /// is not an expression.
    fn parse(input: ParseStream) -> Result<Self> {
        // Gather the raw token trees up to the arrow and parse them as the
        // frequency expression in isolation.
        // NOTE(review): presumably done so the expression parser is never
        // offered the `-` of `->` as a binary operator — confirm.
        let mut freq_tokens = TokenStream::new();
        while !input.is_empty() && !input.peek(Token![->]) {
            freq_tokens.extend(once(input.parse::<TokenTree>()?));
        }

        // Parsing an empty stream with `parse2` would only yield a generic
        // "unexpected end of input" that is not anchored at the entry, so
        // report the missing frequency at the current cursor instead.
        if freq_tokens.is_empty() {
            return Err(input.error("expected a frequency expression before `->`"));
        }

        let freq: Expr = parse2(freq_tokens)?;
        let _: Token![->] = input.parse()?;
        let func: Expr = input.parse()?;
        Ok(TaskEntry { freq, func })
    }
}
/// One `name => { task, task, ... }` group of the `task` macro input.
struct ExecutorEntry {
/// Identifier written before the `=>`; the generated code calls
/// `.add(&mut node)` on it once per task.
name: Ident,
/// Comma-separated task entries found inside the braces.
tasks: Punctuated<TaskEntry, Token![,]>,
}
impl Parse for ExecutorEntry {
fn parse(input: ParseStream) -> Result<Self> {
let name: Ident = input.parse()?;
let _: Token![=>] = input.parse()?;
let content;
syn::braced!(content in input);
let tasks = content.parse_terminated(TaskEntry::parse, Token![,])?;
Ok(ExecutorEntry { name, tasks })
}
}
/// Complete input of `task`: a comma-separated list of executor entries.
struct AddInput(Punctuated<ExecutorEntry, Token![,]>);
impl Parse for AddInput {
fn parse(input: ParseStream) -> Result<Self> {
Ok(AddInput(input.parse_terminated(ExecutorEntry::parse, Token![,])?))
}
}
pub fn task(input: TokenStream) -> TokenStream {
let input = match parse2::<AddInput>(input) {
Ok(val) => val,
Err(err) => return err.to_compile_error(),
};
let mut declarations = Vec::new();
for (exec_idx, exec_entry) in input.0.iter().enumerate() {
let exec_name = &exec_entry.name;
for (task_idx, task) in exec_entry.tasks.iter().enumerate() {
let var_name = format_ident!("__node_{}_{}", exec_idx, task_idx);
let func = &task.func;
let freq = &task.freq;
declarations.push(quote! {
let mut #var_name = TaskNode::new(Box::pin(#func), #freq);
#exec_name.add(&mut #var_name);
});
}
}
quote! {
#(#declarations)*
}
}
/// Token generators for the join-style macros.
///
/// Each public function takes the macro's raw argument tokens — a
/// comma-separated list of expressions — and returns the tokens of an
/// `async` block. A parse failure is returned as a compile error.
pub mod join {
    use super::*;

    /// Comma-separated list of expressions; a trailing comma is accepted.
    struct JoinInput {
        exprs: Punctuated<Expr, Token![,]>,
    }

    impl Parse for JoinInput {
        fn parse(input: ParseStream) -> Result<Self> {
            let exprs = Punctuated::parse_terminated(input)?;
            Ok(JoinInput { exprs })
        }
    }

    /// Expands a list of expressions into one `async` block that awaits them
    /// through nested `Join` (or, when `is_try` is set, `TryJoin`) values.
    ///
    /// * No expressions: `async {}`.
    /// * One expression: `async move { (expr).await }`.
    /// * Several: `async move { (Join { head, tail: Join { .. }, .. }).await }`.
    pub fn join(input: TokenStream, is_try: bool) -> TokenStream {
        let input = match parse2::<JoinInput>(input) {
            Ok(res) => res,
            Err(err) => return err.to_compile_error(),
        };
        let expr_list: Vec<Expr> = input.exprs.into_iter().collect();
        if expr_list.is_empty() {
            return quote! { async {} };
        }
        let tree = build_join_tree(&expr_list, is_try);
        // Parenthesized so `.await` applies to the whole expression. A lone
        // argument is spliced in verbatim, and something like `&mut fut` or
        // `a + b` would otherwise bind `.await` to its last operand only.
        quote! { async move { (#tree).await } }
    }

    /// Folds `exprs` into a right-nested chain of `Join`/`TryJoin` struct
    /// literals: `[a, b, c]` becomes `J { head: a, tail: J { head: b, tail: c, .. }, .. }`.
    /// A single expression is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `exprs` is empty.
    fn build_join_tree(exprs: &[Expr], is_try: bool) -> TokenStream {
        let class = if is_try {
            quote! { ::uefi_async::nano_alloc::control::single::join::TryJoin }
        } else {
            quote! { ::uefi_async::nano_alloc::control::single::join::Join }
        };
        let (last, leading) = exprs
            .split_last()
            .expect("build_join_tree requires at least one expression");
        // Walk from the back so the final expression ends up innermost and
        // each earlier one wraps the tree built so far.
        leading.iter().rfold(quote! { #last }, |tail, head| {
            quote! {
                #class {
                    head: #head,
                    tail: #tail,
                    head_done: false,
                    tail_done: false,
                }
            }
        })
    }

    /// Expands a list of expressions into one `async` block that awaits them
    /// through nested `JoinAll::new` calls and returns the results as a flat
    /// tuple.
    ///
    /// * No expressions: `async { () }`.
    /// * One expression: `async move { (expr).await }` — the bare result, not
    ///   a one-element tuple.
    /// * Several: the nested result `(r0, (r1, r2))` is destructured and
    ///   returned as `(r0, r1, r2)`.
    pub fn join_all(input: TokenStream) -> TokenStream {
        let input = match parse2::<JoinInput>(input) {
            Ok(res) => res,
            Err(err) => return err.to_compile_error(),
        };
        let exprs: Vec<Expr> = input.exprs.into_iter().collect();
        match exprs.as_slice() {
            [] => return quote! { async { () } },
            // Parenthesized for the same precedence reason as in `join`.
            [only] => return quote! { async move { (#only).await } },
            _ => {}
        }

        let tree = build_join_all_tree(&exprs);
        let res_idents: Vec<Ident> = (0..exprs.len())
            .map(|i| format_ident!("__res_{}", i))
            .collect();

        // Mirror the nesting of the tree in the destructuring pattern:
        // `(__res_0, (__res_1, __res_2))` for three expressions.
        let (last_ident, leading) = res_idents
            .split_last()
            .expect("at least two result identifiers exist at this point");
        let pattern = leading
            .iter()
            .rfold(quote! { #last_ident }, |inner, id| quote! { (#id, #inner) });

        quote! {
            async move {
                let #pattern = #tree.await;
                ( #(#res_idents),* )
            }
        }
    }

    /// Folds `exprs` into right-nested `JoinAll::new` calls: `[a, b, c]`
    /// becomes `JoinAll::new(a, JoinAll::new(b, c))`. A single expression is
    /// returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `exprs` is empty.
    fn build_join_all_tree(exprs: &[Expr]) -> TokenStream {
        let (last, leading) = exprs
            .split_last()
            .expect("build_join_all_tree requires at least one expression");
        leading.iter().rfold(quote! { #last }, |tail, head| {
            quote! {
                ::uefi_async::nano_alloc::control::single::join::JoinAll::new(#head, #tail)
            }
        })
    }
}