Skip to main content

outfox_openai_macros/
lib.rs

1//! Macros for the OpenAI crate.
2use proc_macro::TokenStream;
3use quote::{ToTokens, quote};
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8    FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause, parse_macro_input,
9};
10
11// Parse attribute arguments like #[byot(T0: Display + Debug, T1: Clone, R: Serialize)]
12struct BoundArgs {
13    bounds: Vec<(String, syn::TypeParamBound)>,
14    where_clause: Option<String>,
15    stream: bool, // Add stream flag
16}
17
18impl Parse for BoundArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        let mut bounds = Vec::new();
21        let mut where_clause = None;
22        let mut stream = false; // Default to false
23        let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
24
25        for var in vars {
26            let name = var
27                .path
28                .get_ident()
29                .expect("expected identifier")
30                .to_string();
31            match name.as_str() {
32                "where_clause" => {
33                    where_clause = Some(var.value.into_token_stream().to_string());
34                }
35                "stream" => {
36                    stream = var.value.into_token_stream().to_string().contains("true");
37                }
38                _ => {
39                    let bound: syn::TypeParamBound =
40                        syn::parse_str(&var.value.into_token_stream().to_string())?;
41                    bounds.push((name, bound));
42                }
43            }
44        }
45        Ok(Self {
46            bounds,
47            where_clause,
48            stream,
49        })
50    }
51}
52
53/// A passthrough macro that does nothing, used to disable byot processing.
54#[proc_macro_attribute]
55pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
56    item
57}
58
59/// A macro to generate "bring your own type" (byot) functions with generic type parameters.
60#[proc_macro_attribute]
61pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
62    let bounds_args = parse_macro_input!(args as BoundArgs);
63    let input = parse_macro_input!(item as ItemFn);
64    let mut new_generics = Generics::default();
65    let mut param_count = 0;
66
67    // Process function arguments
68    let mut new_params = Vec::new();
69    let args = input
70        .sig
71        .inputs
72        .iter()
73        .map(|arg| {
74            match arg {
75                FnArg::Receiver(receiver) => receiver.to_token_stream(),
76                FnArg::Typed(PatType { pat, .. }) => {
77                    if let Pat::Ident(pat_ident) = &**pat {
78                        let generic_name = format!("T{param_count}");
79                        let generic_ident =
80                            syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
81
82                        // Create type parameter with optional bounds
83                        let mut type_param = TypeParam::from(generic_ident.clone());
84                        if let Some((_, bound)) = bounds_args
85                            .bounds
86                            .iter()
87                            .find(|(name, _)| name == &generic_name)
88                        {
89                            type_param.bounds.extend(vec![bound.clone()]);
90                        }
91
92                        new_params.push(GenericParam::Type(type_param));
93                        param_count += 1;
94                        quote! { #pat_ident: #generic_ident }
95                    } else {
96                        arg.to_token_stream()
97                    }
98                }
99            }
100        })
101        .collect::<Vec<_>>();
102
103    // Add R type parameter with optional bounds
104    let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
105    let mut return_type_param = TypeParam::from(generic_r);
106    if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
107        return_type_param.bounds.extend(vec![bound.clone()]);
108    }
109    new_params.push(GenericParam::Type(return_type_param));
110
111    // Add all generic parameters
112    new_generics.params.extend(new_params);
113
114    let fn_name = &input.sig.ident;
115    let byot_fn_name = syn::Ident::new(&format!("{fn_name}_byot"), fn_name.span());
116    let vis = &input.vis;
117    let block = &input.block;
118    let attrs = &input.attrs;
119    let asyncness = &input.sig.asyncness;
120
121    // Parse where clause if provided
122    let where_clause = if let Some(where_str) = bounds_args.where_clause {
123        match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
124            Ok(where_clause) => quote! { #where_clause },
125            Err(e) => return TokenStream::from(e.to_compile_error()),
126        }
127    } else {
128        quote! {}
129    };
130
131    // Generate return type based on stream flag
132    let return_type = if bounds_args.stream {
133        quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
134    } else {
135        quote! { Result<R, OpenAIError> }
136    };
137
138    let expanded = quote! {
139        #input
140
141        #(#attrs)*
142        #vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
143    };
144
145    expanded.into()
146}