openai_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{quote, ToTokens};
4use syn::parse::{Parse, ParseStream};
5use syn::FieldValue;
6use syn::{parse, parse_macro_input, Expr, Member, Token};
7
/// Role of a chat message, as written in the `message!` macro input.
///
/// Parsed from one of the bare identifiers `user`, `assistant`, `function`
/// or `system`, and later emitted as the matching lowercase string literal.
enum MessageType {
    /// Written as `assistant`.
    Assistant,
    /// Written as `user`.
    User,
    /// Written as `function`.
    Function,
    /// Written as `system`.
    System,
}
14
15impl Parse for MessageType {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let ident: Ident = input.parse()?;
18        Ok(match ident.to_string().as_str() {
19            "user" => MessageType::User,
20            "assistant" => MessageType::Assistant,
21            "function" => MessageType::Function,
22            "system" => MessageType::System,
23            _ => return Err(input.error("unexpected message type")),
24        })
25    }
26}
27
/// Parsed input of the `message!` macro.
struct Message {
    /// Role taken from the single field written without a `:` (e.g. `user`).
    message_type: MessageType,
    /// Expression given for the `user_name:` field, if present.
    user_name: Option<Expr>,
    /// Expression given for the `content:` field, if present.
    content: Option<Expr>,
}
33
34impl Parse for Message {
35    fn parse(input: ParseStream) -> syn::Result<Self> {
36        macro_rules! find_pred {
37            ($l:expr) => {
38                |f: &&FieldValue| -> bool {
39                    if let Member::Named(name) = &f.member {
40                        name.to_string().as_str() == $l
41                    } else {
42                        false
43                    }
44                }
45            };
46        }
47
48        let fields = input.parse_terminated(FieldValue::parse, Token![,])?;
49
50        Ok(Message {
51            message_type: {
52                let fields = fields
53                    .iter()
54                    .filter(|f| f.colon_token.is_none())
55                    .cloned()
56                    .collect::<Vec<FieldValue>>();
57                let field = if fields.len() == 1 {
58                    fields[0].clone()
59                } else {
60                    return Err(input.error("message type not specified"));
61                };
62                let f_val: TokenStream = field.expr.clone().into_token_stream().into();
63                let m_type = parse(f_val)?;
64                m_type
65            },
66            user_name: {
67                let field = fields.iter().find(find_pred!("user_name"));
68                if let Some(f) = field {
69                    Some(f.expr.clone())
70                } else {
71                    None
72                }
73            },
74            content: {
75                let field = fields.iter().find(find_pred!("content"));
76                if let Some(f) = field {
77                    Some(f.expr.clone())
78                } else {
79                    None
80                }
81            },
82        })
83    }
84}
85
86#[proc_macro]
87pub fn message(input: TokenStream) -> TokenStream {
88    let message = parse_macro_input!(input as Message);
89
90    let m_type = match message.message_type {
91        MessageType::Assistant => "assistant",
92        MessageType::User => "user",
93        MessageType::Function => "function",
94        MessageType::System => "system",
95    };
96
97    let mut output = quote! {
98            use openai_utils::Message;
99            Message::new(#m_type)
100    };
101
102    if let Some(content) = message.content {
103        output = quote! {
104            #output.with_content(#content)
105        };
106    }
107
108    if let Some(user) = message.user_name {
109        output = quote! {
110            #output.with_user(#user)
111        };
112    }
113
114    output = quote! {
115        {
116            #output
117        }
118    };
119
120    output.into()
121}
122
/// Parsed input of the `ai_agent!` macro.
///
/// Each field holds the expression written for the identically named input
/// field. Only `model` is required.
struct AiAgent {
    /// Required; passed to `openai_utils::AiAgent::new`.
    model: Expr,
    /// Forwarded to `with_messages`; an array literal is passed as-is, any
    /// other expression is wrapped in a one-element array.
    messages: Option<Expr>,
    /// Forwarded to `with_function_call`.
    function_call: Option<Expr>,
    /// Forwarded to `with_temperature`.
    temperature: Option<Expr>,
    /// Forwarded to `with_top_p`.
    top_p: Option<Expr>,
    /// Forwarded to `with_n`.
    n: Option<Expr>,
    /// Forwarded to `with_stop`.
    stop: Option<Expr>,
    /// Forwarded to `with_max_tokens`.
    max_tokens: Option<Expr>,
    /// Forwarded to `with_presence_penalty`.
    presence_penalty: Option<Expr>,
    /// Forwarded to `with_frequency_penalty`.
    frequency_penalty: Option<Expr>,
    /// Forwarded to `with_system_message`.
    system_message: Option<Expr>,
    /// Forwarded to `with_logit_bias`.
    logit_bias: Option<Expr>,
    /// Forwarded to `with_user`.
    user: Option<Expr>,
}
138
139impl Parse for AiAgent {
140    fn parse(input: ParseStream) -> syn::Result<Self> {
141        macro_rules! find_pred {
142            ($l:expr) => {
143                |f: &&FieldValue| -> bool {
144                    if let Member::Named(name) = &f.member {
145                        name.to_string().as_str() == $l
146                    } else {
147                        false
148                    }
149                }
150            };
151        }
152
153        let fields = input.parse_terminated(FieldValue::parse, Token![,])?;
154
155        let required_field = |l: &str| {
156            let fields = fields
157                .iter()
158                .filter(find_pred!(l))
159                .cloned()
160                .collect::<Vec<FieldValue>>();
161            if fields.len() == 1 {
162                let field = fields[0].clone();
163                let f_val: TokenStream = field.expr.clone().into_token_stream().into();
164                parse(f_val)
165            } else {
166                Err(input.error(format!("'{}' field not specified", l)))
167            }
168        };
169
170        let optional_field = |l: &str| {
171            let field = fields.iter().find(find_pred!(l));
172            if let Some(f) = field {
173                Some(f.expr.clone())
174            } else {
175                None
176            }
177        };
178
179        Ok(Self {
180            model: required_field("model")?,
181            messages: optional_field("messages"),
182            function_call: optional_field("function_call"),
183            temperature: optional_field("temperature"),
184            top_p: optional_field("top_p"),
185            n: optional_field("n"),
186            stop: optional_field("stop"),
187            max_tokens: optional_field("max_tokens"),
188            presence_penalty: optional_field("presence_penalty"),
189            frequency_penalty: optional_field("frequency_penalty"),
190            system_message: optional_field("system_message"),
191            logit_bias: optional_field("logit_bias"),
192            user: optional_field("user"),
193        })
194    }
195}
196
197#[proc_macro]
198pub fn ai_agent(input: TokenStream) -> TokenStream {
199    let input = parse_macro_input!(input as AiAgent);
200
201    let model = input.model;
202
203    let mut b = quote! {
204        openai_utils::AiAgent::new(#model)
205    };
206
207    if let Some(messages) = input.messages {
208        if let Expr::Array(messages) = messages {
209            b = quote! {
210                #b.with_messages(#messages.into())
211            };
212        } else {
213            b = quote! {
214                #b.with_messages([#messages].into())
215            };
216        }
217
218
219    } else {
220        b = quote! {
221            #b.with_messages(vec![])
222        };
223    }
224
225    if let Some(function_call) = input.function_call {
226        b = quote! {
227            #b.with_function_call(#function_call)
228        };
229    }
230
231    if let Some(temperature) = input.temperature {
232        b = quote! {
233            #b.with_temperature(#temperature)
234        };
235    }
236
237    if let Some(top_p) = input.top_p {
238        b = quote! {
239            #b.with_top_p(#top_p)
240        };
241    }
242
243    if let Some(n) = input.n {
244        b = quote! {
245            #b.with_n(#n)
246        };
247    }
248
249    if let Some(stop) = input.stop {
250        b = quote! {
251            #b.with_stop(#stop)
252        };
253    }
254
255    if let Some(max_tokens) = input.max_tokens {
256        b = quote! {
257            #b.with_max_tokens(#max_tokens)
258        };
259    }
260
261    if let Some(presence_penalty) = input.presence_penalty {
262        b = quote! {
263            #b.with_presence_penalty(#presence_penalty)
264        };
265    }
266
267    if let Some(frequency_penalty) = input.frequency_penalty {
268        b = quote! {
269            #b.with_frequency_penalty(#frequency_penalty)
270        };
271    }
272
273    if let Some(logit_bias) = input.logit_bias {
274        b = quote! {
275            #b.with_logit_bias(#logit_bias)
276        };
277    }
278
279    if let Some(system_message) = input.system_message {
280        b = quote! {
281            #b.with_system_message(#system_message)
282        };
283    }
284
285    if let Some(user) = input.user {
286        b = quote! {
287            #b.with_user(#user)
288        };
289    }
290
291
292    b.into()
293}