llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5/// Extract doc comments from attributes
6fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if attr.path().is_ident("doc")
11                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12                && let syn::Expr::Lit(syn::ExprLit {
13                    lit: syn::Lit::Str(lit_str),
14                    ..
15                }) = &meta_name_value.value
16            {
17                return Some(lit_str.value());
18            }
19            None
20        })
21        .map(|s| s.trim().to_string())
22        .collect::<Vec<_>>()
23        .join(" ")
24}
25
/// Outcome of inspecting the `#[prompt(...)]` attributes of one enum variant.
///
/// Produced by `parse_prompt_attribute`; the enum derive uses it to decide how
/// (or whether) a variant appears in the "Possible values" listing.
enum PromptAttribute {
    /// `#[prompt(skip)]`: leave the variant out of the generated prompt.
    Skip,
    /// `#[prompt("...")]`: custom text used instead of the variant's doc comment.
    Description(String),
    /// No recognised `#[prompt]` attribute: fall back to doc comment or bare name.
    None,
}
32
33/// Parse #[prompt(...)] attribute on enum variant
34fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35    for attr in attrs {
36        if attr.path().is_ident("prompt") {
37            // Check for #[prompt(skip)]
38            if let Ok(meta_list) = attr.meta.require_list() {
39                let tokens = &meta_list.tokens;
40                let tokens_str = tokens.to_string();
41                if tokens_str == "skip" {
42                    return PromptAttribute::Skip;
43                }
44            }
45
46            // Check for #[prompt("description")]
47            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48                return PromptAttribute::Description(lit_str.value());
49            }
50        }
51    }
52    PromptAttribute::None
53}
54
/// Options collected from the `#[prompt(...)]` attributes of one struct field.
///
/// Built by `parse_field_prompt_attrs`. Every option is off/unset by default.
#[derive(Default, Debug)]
struct FieldPromptAttrs {
    /// `#[prompt(skip)]`: leave the field out of the default (non-template) output.
    skip: bool,
    /// `#[prompt(rename = "...")]`: key used instead of the doc comment / field name.
    rename: Option<String>,
    /// `#[prompt(format_with = "path::to::fn")]`: function called with `&field` to format it.
    format_with: Option<String>,
    /// `#[prompt(image)]`: the field contributes through its own `to_prompt_parts()`.
    image: bool,
}
63
64/// Parse #[prompt(...)] attributes for struct fields
65fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
66    let mut result = FieldPromptAttrs::default();
67
68    for attr in attrs {
69        if attr.path().is_ident("prompt") {
70            // Try to parse as meta list #[prompt(key = value, ...)]
71            if let Ok(meta_list) = attr.meta.require_list() {
72                // Parse the tokens inside the parentheses
73                if let Ok(metas) =
74                    meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
75                {
76                    for meta in metas {
77                        match meta {
78                            Meta::Path(path) if path.is_ident("skip") => {
79                                result.skip = true;
80                            }
81                            Meta::NameValue(nv) if nv.path.is_ident("rename") => {
82                                if let syn::Expr::Lit(syn::ExprLit {
83                                    lit: syn::Lit::Str(lit_str),
84                                    ..
85                                }) = nv.value
86                                {
87                                    result.rename = Some(lit_str.value());
88                                }
89                            }
90                            Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
91                                if let syn::Expr::Lit(syn::ExprLit {
92                                    lit: syn::Lit::Str(lit_str),
93                                    ..
94                                }) = nv.value
95                                {
96                                    result.format_with = Some(lit_str.value());
97                                }
98                            }
99                            Meta::Path(path) if path.is_ident("image") => {
100                                result.image = true;
101                            }
102                            _ => {}
103                        }
104                    }
105                } else if meta_list.tokens.to_string() == "skip" {
106                    // Handle simple #[prompt(skip)] case
107                    result.skip = true;
108                } else if meta_list.tokens.to_string() == "image" {
109                    // Handle simple #[prompt(image)] case
110                    result.image = true;
111                }
112            }
113        }
114    }
115
116    result
117}
118
119#[proc_macro_derive(ToPrompt, attributes(prompt))]
120pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
121    let input = parse_macro_input!(input as DeriveInput);
122
123    // Check if this is a struct or enum
124    match &input.data {
125        Data::Enum(data_enum) => {
126            // For enums, generate prompt from doc comments
127            let enum_name = &input.ident;
128            let enum_docs = extract_doc_comments(&input.attrs);
129
130            let mut prompt_lines = Vec::new();
131
132            // Add enum description
133            if !enum_docs.is_empty() {
134                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
135            } else {
136                prompt_lines.push(format!("{}:", enum_name));
137            }
138            prompt_lines.push(String::new()); // Empty line
139            prompt_lines.push("Possible values:".to_string());
140
141            // Add each variant with its documentation based on priority
142            for variant in &data_enum.variants {
143                let variant_name = &variant.ident;
144
145                // Apply fallback logic with priority
146                match parse_prompt_attribute(&variant.attrs) {
147                    PromptAttribute::Skip => {
148                        // Skip this variant completely
149                        continue;
150                    }
151                    PromptAttribute::Description(desc) => {
152                        // Use custom description from #[prompt("...")]
153                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
154                    }
155                    PromptAttribute::None => {
156                        // Fall back to doc comment or just variant name
157                        let variant_docs = extract_doc_comments(&variant.attrs);
158                        if !variant_docs.is_empty() {
159                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
160                        } else {
161                            prompt_lines.push(format!("- {}", variant_name));
162                        }
163                    }
164                }
165            }
166
167            let prompt_string = prompt_lines.join("\n");
168            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
169
170            let expanded = quote! {
171                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
172                    fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
173                        vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
174                    }
175
176                    fn to_prompt(&self) -> String {
177                        #prompt_string.to_string()
178                    }
179                }
180            };
181
182            TokenStream::from(expanded)
183        }
184        Data::Struct(data_struct) => {
185            // Check if there's a #[prompt(template = "...")] attribute
186            let template_attr = input
187                .attrs
188                .iter()
189                .find(|attr| attr.path().is_ident("prompt"))
190                .and_then(|attr| {
191                    // Try to parse the attribute arguments
192                    attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
193                        .ok()
194                        .and_then(|metas| {
195                            metas.into_iter().find_map(|meta| match meta {
196                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
197                                    if let syn::Expr::Lit(expr_lit) = nv.value {
198                                        if let syn::Lit::Str(lit_str) = expr_lit.lit {
199                                            Some(lit_str.value())
200                                        } else {
201                                            None
202                                        }
203                                    } else {
204                                        None
205                                    }
206                                }
207                                _ => None,
208                            })
209                        })
210                });
211
212            let name = input.ident;
213            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
214
215            let expanded = if let Some(template_str) = template_attr {
216                // Use template-based approach if template is provided
217                // Collect image fields separately for to_prompt_parts()
218                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
219                    &fields.named
220                } else {
221                    panic!(
222                        "Template prompt generation is only supported for structs with named fields."
223                    );
224                };
225
226                let mut image_field_parts = Vec::new();
227                for f in fields.iter() {
228                    let field_name = f.ident.as_ref().unwrap();
229                    let attrs = parse_field_prompt_attrs(&f.attrs);
230
231                    if attrs.image {
232                        // This field is marked as an image
233                        image_field_parts.push(quote! {
234                            parts.extend(self.#field_name.to_prompt_parts());
235                        });
236                    }
237                }
238
239                quote! {
240                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
241                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
242                            let mut parts = Vec::new();
243
244                            // Add image parts first
245                            #(#image_field_parts)*
246
247                            // Add the rendered template as text
248                            let text = llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
249                                format!("Failed to render prompt: {}", e)
250                            });
251                            if !text.is_empty() {
252                                parts.push(llm_toolkit::prompt::PromptPart::Text(text));
253                            }
254
255                            parts
256                        }
257
258                        fn to_prompt(&self) -> String {
259                            llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
260                                format!("Failed to render prompt: {}", e)
261                            })
262                        }
263                    }
264                }
265            } else {
266                // Use default key-value format if no template is provided
267                // Now also generate to_prompt_parts() for multimodal support
268                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
269                    &fields.named
270                } else {
271                    panic!(
272                        "Default prompt generation is only supported for structs with named fields."
273                    );
274                };
275
276                // Separate image fields from text fields
277                let mut text_field_parts = Vec::new();
278                let mut image_field_parts = Vec::new();
279
280                for f in fields.iter() {
281                    let field_name = f.ident.as_ref().unwrap();
282                    let attrs = parse_field_prompt_attrs(&f.attrs);
283
284                    // Skip if #[prompt(skip)] is present
285                    if attrs.skip {
286                        continue;
287                    }
288
289                    if attrs.image {
290                        // This field is marked as an image
291                        image_field_parts.push(quote! {
292                            parts.extend(self.#field_name.to_prompt_parts());
293                        });
294                    } else {
295                        // This is a regular text field
296                        // Determine the key based on priority:
297                        // 1. #[prompt(rename = "new_name")]
298                        // 2. Doc comment
299                        // 3. Field name (fallback)
300                        let key = if let Some(rename) = attrs.rename {
301                            rename
302                        } else {
303                            let doc_comment = extract_doc_comments(&f.attrs);
304                            if !doc_comment.is_empty() {
305                                doc_comment
306                            } else {
307                                field_name.to_string()
308                            }
309                        };
310
311                        // Determine the value based on format_with attribute
312                        let value_expr = if let Some(format_with) = attrs.format_with {
313                            // Parse the function path string into a syn::Path
314                            let func_path: syn::Path =
315                                syn::parse_str(&format_with).unwrap_or_else(|_| {
316                                    panic!("Invalid function path: {}", format_with)
317                                });
318                            quote! { #func_path(&self.#field_name) }
319                        } else {
320                            quote! { self.#field_name.to_prompt() }
321                        };
322
323                        text_field_parts.push(quote! {
324                            text_parts.push(format!("{}: {}", #key, #value_expr));
325                        });
326                    }
327                }
328
329                // Generate the implementation with to_prompt_parts()
330                quote! {
331                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
332                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
333                            let mut parts = Vec::new();
334
335                            // Add image parts first
336                            #(#image_field_parts)*
337
338                            // Collect text parts and add as a single text prompt part
339                            let mut text_parts = Vec::new();
340                            #(#text_field_parts)*
341
342                            if !text_parts.is_empty() {
343                                parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
344                            }
345
346                            parts
347                        }
348
349                        fn to_prompt(&self) -> String {
350                            let mut text_parts = Vec::new();
351                            #(#text_field_parts)*
352                            text_parts.join("\n")
353                        }
354                    }
355                }
356            };
357
358            TokenStream::from(expanded)
359        }
360        Data::Union(_) => {
361            panic!("`#[derive(ToPrompt)]` is not supported for unions");
362        }
363    }
364}