llm_toolkit_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if attr.path().is_ident("doc")
11 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12 && let syn::Expr::Lit(syn::ExprLit {
13 lit: syn::Lit::Str(lit_str),
14 ..
15 }) = &meta_name_value.value
16 {
17 return Some(lit_str.value());
18 }
19 None
20 })
21 .map(|s| s.trim().to_string())
22 .collect::<Vec<_>>()
23 .join(" ")
24}
25
/// Outcome of reading the `#[prompt(...)]` attribute on an enum variant.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the generated prompt.
    Skip,
    /// `#[prompt("...")]`: explicit description used instead of the doc comment.
    Description(String),
    /// No usable `#[prompt]` attribute; the variant's doc comment is used if present.
    None,
}
32
33fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35 for attr in attrs {
36 if attr.path().is_ident("prompt") {
37 if let Ok(meta_list) = attr.meta.require_list() {
39 let tokens = &meta_list.tokens;
40 let tokens_str = tokens.to_string();
41 if tokens_str == "skip" {
42 return PromptAttribute::Skip;
43 }
44 }
45
46 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48 return PromptAttribute::Description(lit_str.value());
49 }
50 }
51 }
52 PromptAttribute::None
53}
54
/// Options collected from the `#[prompt(...)]` attributes on a single struct field.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `#[prompt(skip)]`: omit the field from the default (non-template) prompt.
    skip: bool,
    /// `#[prompt(rename = "...")]`: label used instead of the doc comment or field name.
    rename: Option<String>,
    /// `#[prompt(format_with = "path::to::fn")]`: function called with `&field` to render its value.
    format_with: Option<String>,
    /// `#[prompt(image)]`: the field contributes its own prompt parts instead of a text line.
    image: bool,
}
63
64fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
66 let mut result = FieldPromptAttrs::default();
67
68 for attr in attrs {
69 if attr.path().is_ident("prompt") {
70 if let Ok(meta_list) = attr.meta.require_list() {
72 if let Ok(metas) =
74 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
75 {
76 for meta in metas {
77 match meta {
78 Meta::Path(path) if path.is_ident("skip") => {
79 result.skip = true;
80 }
81 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
82 if let syn::Expr::Lit(syn::ExprLit {
83 lit: syn::Lit::Str(lit_str),
84 ..
85 }) = nv.value
86 {
87 result.rename = Some(lit_str.value());
88 }
89 }
90 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
91 if let syn::Expr::Lit(syn::ExprLit {
92 lit: syn::Lit::Str(lit_str),
93 ..
94 }) = nv.value
95 {
96 result.format_with = Some(lit_str.value());
97 }
98 }
99 Meta::Path(path) if path.is_ident("image") => {
100 result.image = true;
101 }
102 _ => {}
103 }
104 }
105 } else if meta_list.tokens.to_string() == "skip" {
106 result.skip = true;
108 } else if meta_list.tokens.to_string() == "image" {
109 result.image = true;
111 }
112 }
113 }
114 }
115
116 result
117}
118
119#[proc_macro_derive(ToPrompt, attributes(prompt))]
120pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
121 let input = parse_macro_input!(input as DeriveInput);
122
123 match &input.data {
125 Data::Enum(data_enum) => {
126 let enum_name = &input.ident;
128 let enum_docs = extract_doc_comments(&input.attrs);
129
130 let mut prompt_lines = Vec::new();
131
132 if !enum_docs.is_empty() {
134 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
135 } else {
136 prompt_lines.push(format!("{}:", enum_name));
137 }
138 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
140
141 for variant in &data_enum.variants {
143 let variant_name = &variant.ident;
144
145 match parse_prompt_attribute(&variant.attrs) {
147 PromptAttribute::Skip => {
148 continue;
150 }
151 PromptAttribute::Description(desc) => {
152 prompt_lines.push(format!("- {}: {}", variant_name, desc));
154 }
155 PromptAttribute::None => {
156 let variant_docs = extract_doc_comments(&variant.attrs);
158 if !variant_docs.is_empty() {
159 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
160 } else {
161 prompt_lines.push(format!("- {}", variant_name));
162 }
163 }
164 }
165 }
166
167 let prompt_string = prompt_lines.join("\n");
168 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
169
170 let expanded = quote! {
171 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
172 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
173 vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
174 }
175
176 fn to_prompt(&self) -> String {
177 #prompt_string.to_string()
178 }
179 }
180 };
181
182 TokenStream::from(expanded)
183 }
184 Data::Struct(data_struct) => {
185 let template_attr = input
187 .attrs
188 .iter()
189 .find(|attr| attr.path().is_ident("prompt"))
190 .and_then(|attr| {
191 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
193 .ok()
194 .and_then(|metas| {
195 metas.into_iter().find_map(|meta| match meta {
196 Meta::NameValue(nv) if nv.path.is_ident("template") => {
197 if let syn::Expr::Lit(expr_lit) = nv.value {
198 if let syn::Lit::Str(lit_str) = expr_lit.lit {
199 Some(lit_str.value())
200 } else {
201 None
202 }
203 } else {
204 None
205 }
206 }
207 _ => None,
208 })
209 })
210 });
211
212 let name = input.ident;
213 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
214
215 let expanded = if let Some(template_str) = template_attr {
216 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
219 &fields.named
220 } else {
221 panic!(
222 "Template prompt generation is only supported for structs with named fields."
223 );
224 };
225
226 let mut image_field_parts = Vec::new();
227 for f in fields.iter() {
228 let field_name = f.ident.as_ref().unwrap();
229 let attrs = parse_field_prompt_attrs(&f.attrs);
230
231 if attrs.image {
232 image_field_parts.push(quote! {
234 parts.extend(self.#field_name.to_prompt_parts());
235 });
236 }
237 }
238
239 quote! {
240 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
241 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
242 let mut parts = Vec::new();
243
244 #(#image_field_parts)*
246
247 let text = llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
249 format!("Failed to render prompt: {}", e)
250 });
251 if !text.is_empty() {
252 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
253 }
254
255 parts
256 }
257
258 fn to_prompt(&self) -> String {
259 llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
260 format!("Failed to render prompt: {}", e)
261 })
262 }
263 }
264 }
265 } else {
266 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
269 &fields.named
270 } else {
271 panic!(
272 "Default prompt generation is only supported for structs with named fields."
273 );
274 };
275
276 let mut text_field_parts = Vec::new();
278 let mut image_field_parts = Vec::new();
279
280 for f in fields.iter() {
281 let field_name = f.ident.as_ref().unwrap();
282 let attrs = parse_field_prompt_attrs(&f.attrs);
283
284 if attrs.skip {
286 continue;
287 }
288
289 if attrs.image {
290 image_field_parts.push(quote! {
292 parts.extend(self.#field_name.to_prompt_parts());
293 });
294 } else {
295 let key = if let Some(rename) = attrs.rename {
301 rename
302 } else {
303 let doc_comment = extract_doc_comments(&f.attrs);
304 if !doc_comment.is_empty() {
305 doc_comment
306 } else {
307 field_name.to_string()
308 }
309 };
310
311 let value_expr = if let Some(format_with) = attrs.format_with {
313 let func_path: syn::Path =
315 syn::parse_str(&format_with).unwrap_or_else(|_| {
316 panic!("Invalid function path: {}", format_with)
317 });
318 quote! { #func_path(&self.#field_name) }
319 } else {
320 quote! { self.#field_name.to_prompt() }
321 };
322
323 text_field_parts.push(quote! {
324 text_parts.push(format!("{}: {}", #key, #value_expr));
325 });
326 }
327 }
328
329 quote! {
331 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
332 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
333 let mut parts = Vec::new();
334
335 #(#image_field_parts)*
337
338 let mut text_parts = Vec::new();
340 #(#text_field_parts)*
341
342 if !text_parts.is_empty() {
343 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
344 }
345
346 parts
347 }
348
349 fn to_prompt(&self) -> String {
350 let mut text_parts = Vec::new();
351 #(#text_field_parts)*
352 text_parts.join("\n")
353 }
354 }
355 }
356 };
357
358 TokenStream::from(expanded)
359 }
360 Data::Union(_) => {
361 panic!("`#[derive(ToPrompt)]` is not supported for unions");
362 }
363 }
364}