1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{quote, ToTokens};
4use syn::parse::{Parse, ParseStream};
5use syn::FieldValue;
6use syn::{parse, parse_macro_input, Expr, Member, Token};
7
/// Role of a chat message as written in the input of the `message!` macro.
///
/// Each variant corresponds to one lowercase identifier accepted by the
/// `Parse` implementation: `assistant`, `user`, `function` and `system`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MessageType {
    /// Written as `assistant` in macro input.
    Assistant,
    /// Written as `user` in macro input.
    User,
    /// Written as `function` in macro input.
    Function,
    /// Written as `system` in macro input.
    System,
}
14
15impl Parse for MessageType {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let ident: Ident = input.parse()?;
18 Ok(match ident.to_string().as_str() {
19 "user" => MessageType::User,
20 "assistant" => MessageType::Assistant,
21 "function" => MessageType::Function,
22 "system" => MessageType::System,
23 _ => return Err(input.error("unexpected message type")),
24 })
25 }
26}
27
28struct Message {
29 message_type: MessageType,
30 user_name: Option<Expr>,
31 content: Option<Expr>,
32}
33
34impl Parse for Message {
35 fn parse(input: ParseStream) -> syn::Result<Self> {
36 macro_rules! find_pred {
37 ($l:expr) => {
38 |f: &&FieldValue| -> bool {
39 if let Member::Named(name) = &f.member {
40 name.to_string().as_str() == $l
41 } else {
42 false
43 }
44 }
45 };
46 }
47
48 let fields = input.parse_terminated(FieldValue::parse, Token![,])?;
49
50 Ok(Message {
51 message_type: {
52 let fields = fields
53 .iter()
54 .filter(|f| f.colon_token.is_none())
55 .cloned()
56 .collect::<Vec<FieldValue>>();
57 let field = if fields.len() == 1 {
58 fields[0].clone()
59 } else {
60 return Err(input.error("message type not specified"));
61 };
62 let f_val: TokenStream = field.expr.clone().into_token_stream().into();
63 let m_type = parse(f_val)?;
64 m_type
65 },
66 user_name: {
67 let field = fields.iter().find(find_pred!("user_name"));
68 if let Some(f) = field {
69 Some(f.expr.clone())
70 } else {
71 None
72 }
73 },
74 content: {
75 let field = fields.iter().find(find_pred!("content"));
76 if let Some(f) = field {
77 Some(f.expr.clone())
78 } else {
79 None
80 }
81 },
82 })
83 }
84}
85
86#[proc_macro]
87pub fn message(input: TokenStream) -> TokenStream {
88 let message = parse_macro_input!(input as Message);
89
90 let m_type = match message.message_type {
91 MessageType::Assistant => "assistant",
92 MessageType::User => "user",
93 MessageType::Function => "function",
94 MessageType::System => "system",
95 };
96
97 let mut output = quote! {
98 use openai_utils::Message;
99 Message::new(#m_type)
100 };
101
102 if let Some(content) = message.content {
103 output = quote! {
104 #output.with_content(#content)
105 };
106 }
107
108 if let Some(user) = message.user_name {
109 output = quote! {
110 #output.with_user(#user)
111 };
112 }
113
114 output = quote! {
115 {
116 #output
117 }
118 };
119
120 output.into()
121}
122
123struct AiAgent {
124 model: Expr,
125 messages: Option<Expr>,
126 function_call: Option<Expr>,
127 temperature: Option<Expr>,
128 top_p: Option<Expr>,
129 n: Option<Expr>,
130 stop: Option<Expr>,
131 max_tokens: Option<Expr>,
132 presence_penalty: Option<Expr>,
133 frequency_penalty: Option<Expr>,
134 system_message: Option<Expr>,
135 logit_bias: Option<Expr>,
136 user: Option<Expr>,
137}
138
139impl Parse for AiAgent {
140 fn parse(input: ParseStream) -> syn::Result<Self> {
141 macro_rules! find_pred {
142 ($l:expr) => {
143 |f: &&FieldValue| -> bool {
144 if let Member::Named(name) = &f.member {
145 name.to_string().as_str() == $l
146 } else {
147 false
148 }
149 }
150 };
151 }
152
153 let fields = input.parse_terminated(FieldValue::parse, Token![,])?;
154
155 let required_field = |l: &str| {
156 let fields = fields
157 .iter()
158 .filter(find_pred!(l))
159 .cloned()
160 .collect::<Vec<FieldValue>>();
161 if fields.len() == 1 {
162 let field = fields[0].clone();
163 let f_val: TokenStream = field.expr.clone().into_token_stream().into();
164 parse(f_val)
165 } else {
166 Err(input.error(format!("'{}' field not specified", l)))
167 }
168 };
169
170 let optional_field = |l: &str| {
171 let field = fields.iter().find(find_pred!(l));
172 if let Some(f) = field {
173 Some(f.expr.clone())
174 } else {
175 None
176 }
177 };
178
179 Ok(Self {
180 model: required_field("model")?,
181 messages: optional_field("messages"),
182 function_call: optional_field("function_call"),
183 temperature: optional_field("temperature"),
184 top_p: optional_field("top_p"),
185 n: optional_field("n"),
186 stop: optional_field("stop"),
187 max_tokens: optional_field("max_tokens"),
188 presence_penalty: optional_field("presence_penalty"),
189 frequency_penalty: optional_field("frequency_penalty"),
190 system_message: optional_field("system_message"),
191 logit_bias: optional_field("logit_bias"),
192 user: optional_field("user"),
193 })
194 }
195}
196
197#[proc_macro]
198pub fn ai_agent(input: TokenStream) -> TokenStream {
199 let input = parse_macro_input!(input as AiAgent);
200
201 let model = input.model;
202
203 let mut b = quote! {
204 openai_utils::AiAgent::new(#model)
205 };
206
207 if let Some(messages) = input.messages {
208 if let Expr::Array(messages) = messages {
209 b = quote! {
210 #b.with_messages(#messages.into())
211 };
212 } else {
213 b = quote! {
214 #b.with_messages([#messages].into())
215 };
216 }
217
218
219 } else {
220 b = quote! {
221 #b.with_messages(vec![])
222 };
223 }
224
225 if let Some(function_call) = input.function_call {
226 b = quote! {
227 #b.with_function_call(#function_call)
228 };
229 }
230
231 if let Some(temperature) = input.temperature {
232 b = quote! {
233 #b.with_temperature(#temperature)
234 };
235 }
236
237 if let Some(top_p) = input.top_p {
238 b = quote! {
239 #b.with_top_p(#top_p)
240 };
241 }
242
243 if let Some(n) = input.n {
244 b = quote! {
245 #b.with_n(#n)
246 };
247 }
248
249 if let Some(stop) = input.stop {
250 b = quote! {
251 #b.with_stop(#stop)
252 };
253 }
254
255 if let Some(max_tokens) = input.max_tokens {
256 b = quote! {
257 #b.with_max_tokens(#max_tokens)
258 };
259 }
260
261 if let Some(presence_penalty) = input.presence_penalty {
262 b = quote! {
263 #b.with_presence_penalty(#presence_penalty)
264 };
265 }
266
267 if let Some(frequency_penalty) = input.frequency_penalty {
268 b = quote! {
269 #b.with_frequency_penalty(#frequency_penalty)
270 };
271 }
272
273 if let Some(logit_bias) = input.logit_bias {
274 b = quote! {
275 #b.with_logit_bias(#logit_bias)
276 };
277 }
278
279 if let Some(system_message) = input.system_message {
280 b = quote! {
281 #b.with_system_message(#system_message)
282 };
283 }
284
285 if let Some(user) = input.user {
286 b = quote! {
287 #b.with_user(#user)
288 };
289 }
290
291
292 b.into()
293}