async_openai_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input,
6 punctuated::Punctuated,
7 token::Comma,
8 FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
9};
10
11struct BoundArgs {
13 bounds: Vec<(String, syn::TypeParamBound)>,
14 where_clause: Option<String>,
15 stream: bool, }
17
18impl Parse for BoundArgs {
19 fn parse(input: ParseStream) -> syn::Result<Self> {
20 let mut bounds = Vec::new();
21 let mut where_clause = None;
22 let mut stream = false; let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
24
25 for var in vars {
26 let name = var.path.get_ident().unwrap().to_string();
27 match name.as_str() {
28 "where_clause" => {
29 where_clause = Some(var.value.into_token_stream().to_string());
30 }
31 "stream" => {
32 stream = var.value.into_token_stream().to_string().contains("true");
33 }
34 _ => {
35 let bound: syn::TypeParamBound =
36 syn::parse_str(&var.value.into_token_stream().to_string())?;
37 bounds.push((name, bound));
38 }
39 }
40 }
41 Ok(BoundArgs {
42 bounds,
43 where_clause,
44 stream,
45 })
46 }
47}
48
49#[proc_macro_attribute]
50pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
51 item
52}
53
54#[proc_macro_attribute]
55pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
56 let bounds_args = parse_macro_input!(args as BoundArgs);
57 let input = parse_macro_input!(item as ItemFn);
58 let mut new_generics = Generics::default();
59 let mut param_count = 0;
60
61 let mut new_params = Vec::new();
63 let args = input
64 .sig
65 .inputs
66 .iter()
67 .map(|arg| {
68 match arg {
69 FnArg::Receiver(receiver) => receiver.to_token_stream(),
70 FnArg::Typed(PatType { pat, .. }) => {
71 if let Pat::Ident(pat_ident) = &**pat {
72 let generic_name = format!("T{}", param_count);
73 let generic_ident =
74 syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
75
76 let mut type_param = TypeParam::from(generic_ident.clone());
78 if let Some((_, bound)) = bounds_args
79 .bounds
80 .iter()
81 .find(|(name, _)| name == &generic_name)
82 {
83 type_param.bounds.extend(vec![bound.clone()]);
84 }
85
86 new_params.push(GenericParam::Type(type_param));
87 param_count += 1;
88 quote! { #pat_ident: #generic_ident }
89 } else {
90 arg.to_token_stream()
91 }
92 }
93 }
94 })
95 .collect::<Vec<_>>();
96
97 let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
99 let mut return_type_param = TypeParam::from(generic_r.clone());
100 if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
101 return_type_param.bounds.extend(vec![bound.clone()]);
102 }
103 new_params.push(GenericParam::Type(return_type_param));
104
105 new_generics.params.extend(new_params);
107
108 let fn_name = &input.sig.ident;
109 let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span());
110 let vis = &input.vis;
111 let block = &input.block;
112 let attrs = &input.attrs;
113 let asyncness = &input.sig.asyncness;
114
115 let where_clause = if let Some(where_str) = bounds_args.where_clause {
117 match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
118 Ok(where_clause) => quote! { #where_clause },
119 Err(e) => return TokenStream::from(e.to_compile_error()),
120 }
121 } else {
122 quote! {}
123 };
124
125 let return_type = if bounds_args.stream {
127 quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
128 } else {
129 quote! { Result<R, OpenAIError> }
130 };
131
132 let expanded = quote! {
133 #(#attrs)*
134 #input
135
136 #(#attrs)*
137 #vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
138 };
139
140 expanded.into()
141}