async_openai_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input,
6    punctuated::Punctuated,
7    token::Comma,
8    FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
9};
10
11// Parse attribute arguments like #[byot(T0: Display + Debug, T1: Clone, R: Serialize)]
12struct BoundArgs {
13    bounds: Vec<(String, syn::TypeParamBound)>,
14    where_clause: Option<String>,
15    stream: bool, // Add stream flag
16}
17
18impl Parse for BoundArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        let mut bounds = Vec::new();
21        let mut where_clause = None;
22        let mut stream = false; // Default to false
23        let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
24
25        for var in vars {
26            let name = var.path.get_ident().unwrap().to_string();
27            match name.as_str() {
28                "where_clause" => {
29                    where_clause = Some(var.value.into_token_stream().to_string());
30                }
31                "stream" => {
32                    stream = var.value.into_token_stream().to_string().contains("true");
33                }
34                _ => {
35                    let bound: syn::TypeParamBound =
36                        syn::parse_str(&var.value.into_token_stream().to_string())?;
37                    bounds.push((name, bound));
38                }
39            }
40        }
41        Ok(BoundArgs {
42            bounds,
43            where_clause,
44            stream,
45        })
46    }
47}
48
49#[proc_macro_attribute]
50pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
51    item
52}
53
54#[proc_macro_attribute]
55pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
56    let bounds_args = parse_macro_input!(args as BoundArgs);
57    let input = parse_macro_input!(item as ItemFn);
58    let mut new_generics = Generics::default();
59    let mut param_count = 0;
60
61    // Process function arguments
62    let mut new_params = Vec::new();
63    let args = input
64        .sig
65        .inputs
66        .iter()
67        .map(|arg| {
68            match arg {
69                FnArg::Receiver(receiver) => receiver.to_token_stream(),
70                FnArg::Typed(PatType { pat, .. }) => {
71                    if let Pat::Ident(pat_ident) = &**pat {
72                        let generic_name = format!("T{}", param_count);
73                        let generic_ident =
74                            syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
75
76                        // Create type parameter with optional bounds
77                        let mut type_param = TypeParam::from(generic_ident.clone());
78                        if let Some((_, bound)) = bounds_args
79                            .bounds
80                            .iter()
81                            .find(|(name, _)| name == &generic_name)
82                        {
83                            type_param.bounds.extend(vec![bound.clone()]);
84                        }
85
86                        new_params.push(GenericParam::Type(type_param));
87                        param_count += 1;
88                        quote! { #pat_ident: #generic_ident }
89                    } else {
90                        arg.to_token_stream()
91                    }
92                }
93            }
94        })
95        .collect::<Vec<_>>();
96
97    // Add R type parameter with optional bounds
98    let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
99    let mut return_type_param = TypeParam::from(generic_r.clone());
100    if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
101        return_type_param.bounds.extend(vec![bound.clone()]);
102    }
103    new_params.push(GenericParam::Type(return_type_param));
104
105    // Add all generic parameters
106    new_generics.params.extend(new_params);
107
108    let fn_name = &input.sig.ident;
109    let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span());
110    let vis = &input.vis;
111    let block = &input.block;
112    let attrs = &input.attrs;
113    let asyncness = &input.sig.asyncness;
114
115    // Parse where clause if provided
116    let where_clause = if let Some(where_str) = bounds_args.where_clause {
117        match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
118            Ok(where_clause) => quote! { #where_clause },
119            Err(e) => return TokenStream::from(e.to_compile_error()),
120        }
121    } else {
122        quote! {}
123    };
124
125    // Generate return type based on stream flag
126    let return_type = if bounds_args.stream {
127        quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
128    } else {
129        quote! { Result<R, OpenAIError> }
130    };
131
132    let expanded = quote! {
133        #(#attrs)*
134        #input
135
136        #(#attrs)*
137        #vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
138    };
139
140    expanded.into()
141}