outfox_openai_macros/
lib.rs1use proc_macro::TokenStream;
3use quote::{ToTokens, quote};
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8 FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause, parse_macro_input,
9};
10
11struct BoundArgs {
13 bounds: Vec<(String, syn::TypeParamBound)>,
14 where_clause: Option<String>,
15 stream: bool, }
17
18impl Parse for BoundArgs {
19 fn parse(input: ParseStream) -> syn::Result<Self> {
20 let mut bounds = Vec::new();
21 let mut where_clause = None;
22 let mut stream = false; let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
24
25 for var in vars {
26 let name = var
27 .path
28 .get_ident()
29 .expect("expected identifier")
30 .to_string();
31 match name.as_str() {
32 "where_clause" => {
33 where_clause = Some(var.value.into_token_stream().to_string());
34 }
35 "stream" => {
36 stream = var.value.into_token_stream().to_string().contains("true");
37 }
38 _ => {
39 let bound: syn::TypeParamBound =
40 syn::parse_str(&var.value.into_token_stream().to_string())?;
41 bounds.push((name, bound));
42 }
43 }
44 }
45 Ok(Self {
46 bounds,
47 where_clause,
48 stream,
49 })
50 }
51}
52
53#[proc_macro_attribute]
55pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
56 item
57}
58
59#[proc_macro_attribute]
61pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
62 let bounds_args = parse_macro_input!(args as BoundArgs);
63 let input = parse_macro_input!(item as ItemFn);
64 let mut new_generics = Generics::default();
65 let mut param_count = 0;
66
67 let mut new_params = Vec::new();
69 let args = input
70 .sig
71 .inputs
72 .iter()
73 .map(|arg| {
74 match arg {
75 FnArg::Receiver(receiver) => receiver.to_token_stream(),
76 FnArg::Typed(PatType { pat, .. }) => {
77 if let Pat::Ident(pat_ident) = &**pat {
78 let generic_name = format!("T{param_count}");
79 let generic_ident =
80 syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
81
82 let mut type_param = TypeParam::from(generic_ident.clone());
84 if let Some((_, bound)) = bounds_args
85 .bounds
86 .iter()
87 .find(|(name, _)| name == &generic_name)
88 {
89 type_param.bounds.extend(vec![bound.clone()]);
90 }
91
92 new_params.push(GenericParam::Type(type_param));
93 param_count += 1;
94 quote! { #pat_ident: #generic_ident }
95 } else {
96 arg.to_token_stream()
97 }
98 }
99 }
100 })
101 .collect::<Vec<_>>();
102
103 let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
105 let mut return_type_param = TypeParam::from(generic_r);
106 if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
107 return_type_param.bounds.extend(vec![bound.clone()]);
108 }
109 new_params.push(GenericParam::Type(return_type_param));
110
111 new_generics.params.extend(new_params);
113
114 let fn_name = &input.sig.ident;
115 let byot_fn_name = syn::Ident::new(&format!("{fn_name}_byot"), fn_name.span());
116 let vis = &input.vis;
117 let block = &input.block;
118 let attrs = &input.attrs;
119 let asyncness = &input.sig.asyncness;
120
121 let where_clause = if let Some(where_str) = bounds_args.where_clause {
123 match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
124 Ok(where_clause) => quote! { #where_clause },
125 Err(e) => return TokenStream::from(e.to_compile_error()),
126 }
127 } else {
128 quote! {}
129 };
130
131 let return_type = if bounds_args.stream {
133 quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
134 } else {
135 quote! { Result<R, OpenAIError> }
136 };
137
138 let expanded = quote! {
139 #input
140
141 #(#attrs)*
142 #vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
143 };
144
145 expanded.into()
146}