use proc_macro::TokenStream;
use quote::{ToTokens, quote};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause, parse_macro_input,
};
/// Parsed arguments of the `#[byot(...)]` attribute.
struct BoundArgs {
    /// `(key, bound)` pairs for every key other than `where_clause` and
    /// `stream`; the key is expected to name a generated generic such as
    /// `T0`, `T1`, … or `R`.
    bounds: Vec<(String, syn::TypeParamBound)>,
    /// Text of the `where_clause` argument, without the leading `where`
    /// keyword (it is prepended when the clause is built).
    where_clause: Option<String>,
    /// When set, the generated function returns a pinned, boxed stream of
    /// `Result<R, OpenAIError>` items instead of a single `Result<R, _>`.
    stream: bool,
}
impl Parse for BoundArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut bounds = Vec::new();
let mut where_clause = None;
let mut stream = false; let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
for var in vars {
let name = var
.path
.get_ident()
.expect("expected identifier")
.to_string();
match name.as_str() {
"where_clause" => {
where_clause = Some(var.value.into_token_stream().to_string());
}
"stream" => {
stream = var.value.into_token_stream().to_string().contains("true");
}
_ => {
let bound: syn::TypeParamBound =
syn::parse_str(&var.value.into_token_stream().to_string())?;
bounds.push((name, bound));
}
}
}
Ok(Self {
bounds,
where_clause,
stream,
})
}
}
/// No-op attribute: ignores its arguments and emits the annotated item
/// exactly as it was written.
#[proc_macro_attribute]
pub fn byot_passthrough(_attr: TokenStream, annotated: TokenStream) -> TokenStream {
    annotated
}
#[proc_macro_attribute]
pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
let bounds_args = parse_macro_input!(args as BoundArgs);
let input = parse_macro_input!(item as ItemFn);
let mut new_generics = Generics::default();
let mut param_count = 0;
let mut new_params = Vec::new();
let args = input
.sig
.inputs
.iter()
.map(|arg| {
match arg {
FnArg::Receiver(receiver) => receiver.to_token_stream(),
FnArg::Typed(PatType { pat, .. }) => {
if let Pat::Ident(pat_ident) = &**pat {
let generic_name = format!("T{param_count}");
let generic_ident =
syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
let mut type_param = TypeParam::from(generic_ident.clone());
if let Some((_, bound)) = bounds_args
.bounds
.iter()
.find(|(name, _)| name == &generic_name)
{
type_param.bounds.extend(vec![bound.clone()]);
}
new_params.push(GenericParam::Type(type_param));
param_count += 1;
quote! { #pat_ident: #generic_ident }
} else {
arg.to_token_stream()
}
}
}
})
.collect::<Vec<_>>();
let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
let mut return_type_param = TypeParam::from(generic_r);
if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
return_type_param.bounds.extend(vec![bound.clone()]);
}
new_params.push(GenericParam::Type(return_type_param));
new_generics.params.extend(new_params);
let fn_name = &input.sig.ident;
let byot_fn_name = syn::Ident::new(&format!("{fn_name}_byot"), fn_name.span());
let vis = &input.vis;
let block = &input.block;
let attrs = &input.attrs;
let asyncness = &input.sig.asyncness;
let where_clause = if let Some(where_str) = bounds_args.where_clause {
match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
Ok(where_clause) => quote! { #where_clause },
Err(e) => return TokenStream::from(e.to_compile_error()),
}
} else {
quote! {}
};
let return_type = if bounds_args.stream {
quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
} else {
quote! { Result<R, OpenAIError> }
};
let expanded = quote! {
#input
#(#attrs)*
#vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
};
expanded.into()
}