llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5/// Extract doc comments from attributes
6fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if attr.path().is_ident("doc")
11                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12                && let syn::Expr::Lit(syn::ExprLit {
13                    lit: syn::Lit::Str(lit_str),
14                    ..
15                }) = &meta_name_value.value
16            {
17                return Some(lit_str.value());
18            }
19            None
20        })
21        .map(|s| s.trim().to_string())
22        .collect::<Vec<_>>()
23        .join(" ")
24}
25
26/// Result of parsing prompt attribute
27enum PromptAttribute {
28    Skip,
29    Description(String),
30    None,
31}
32
33/// Parse #[prompt(...)] attribute on enum variant
34fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35    for attr in attrs {
36        if attr.path().is_ident("prompt") {
37            // Check for #[prompt(skip)]
38            if let Ok(meta_list) = attr.meta.require_list() {
39                let tokens = &meta_list.tokens;
40                let tokens_str = tokens.to_string();
41                if tokens_str == "skip" {
42                    return PromptAttribute::Skip;
43                }
44            }
45
46            // Check for #[prompt("description")]
47            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48                return PromptAttribute::Description(lit_str.value());
49            }
50        }
51    }
52    PromptAttribute::None
53}
54
55/// Parsed field-level prompt attributes
56#[derive(Debug, Default)]
57struct FieldPromptAttrs {
58    skip: bool,
59    rename: Option<String>,
60    format_with: Option<String>,
61}
62
63/// Parse #[prompt(...)] attributes for struct fields
64fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
65    let mut result = FieldPromptAttrs::default();
66
67    for attr in attrs {
68        if attr.path().is_ident("prompt") {
69            // Try to parse as meta list #[prompt(key = value, ...)]
70            if let Ok(meta_list) = attr.meta.require_list() {
71                // Parse the tokens inside the parentheses
72                if let Ok(metas) =
73                    meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
74                {
75                    for meta in metas {
76                        match meta {
77                            Meta::Path(path) if path.is_ident("skip") => {
78                                result.skip = true;
79                            }
80                            Meta::NameValue(nv) if nv.path.is_ident("rename") => {
81                                if let syn::Expr::Lit(syn::ExprLit {
82                                    lit: syn::Lit::Str(lit_str),
83                                    ..
84                                }) = nv.value
85                                {
86                                    result.rename = Some(lit_str.value());
87                                }
88                            }
89                            Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
90                                if let syn::Expr::Lit(syn::ExprLit {
91                                    lit: syn::Lit::Str(lit_str),
92                                    ..
93                                }) = nv.value
94                                {
95                                    result.format_with = Some(lit_str.value());
96                                }
97                            }
98                            _ => {}
99                        }
100                    }
101                } else if meta_list.tokens.to_string() == "skip" {
102                    // Handle simple #[prompt(skip)] case
103                    result.skip = true;
104                }
105            }
106        }
107    }
108
109    result
110}
111
112#[proc_macro_derive(ToPrompt, attributes(prompt))]
113pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
114    let input = parse_macro_input!(input as DeriveInput);
115
116    // Check if this is a struct or enum
117    match &input.data {
118        Data::Enum(data_enum) => {
119            // For enums, generate prompt from doc comments
120            let enum_name = &input.ident;
121            let enum_docs = extract_doc_comments(&input.attrs);
122
123            let mut prompt_lines = Vec::new();
124
125            // Add enum description
126            if !enum_docs.is_empty() {
127                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
128            } else {
129                prompt_lines.push(format!("{}:", enum_name));
130            }
131            prompt_lines.push(String::new()); // Empty line
132            prompt_lines.push("Possible values:".to_string());
133
134            // Add each variant with its documentation based on priority
135            for variant in &data_enum.variants {
136                let variant_name = &variant.ident;
137
138                // Apply fallback logic with priority
139                match parse_prompt_attribute(&variant.attrs) {
140                    PromptAttribute::Skip => {
141                        // Skip this variant completely
142                        continue;
143                    }
144                    PromptAttribute::Description(desc) => {
145                        // Use custom description from #[prompt("...")]
146                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
147                    }
148                    PromptAttribute::None => {
149                        // Fall back to doc comment or just variant name
150                        let variant_docs = extract_doc_comments(&variant.attrs);
151                        if !variant_docs.is_empty() {
152                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
153                        } else {
154                            prompt_lines.push(format!("- {}", variant_name));
155                        }
156                    }
157                }
158            }
159
160            let prompt_string = prompt_lines.join("\n");
161            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
162
163            let expanded = quote! {
164                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
165                    fn to_prompt(&self) -> String {
166                        #prompt_string.to_string()
167                    }
168                }
169            };
170
171            TokenStream::from(expanded)
172        }
173        Data::Struct(data_struct) => {
174            // Check if there's a #[prompt(template = "...")] attribute
175            let template_attr = input
176                .attrs
177                .iter()
178                .find(|attr| attr.path().is_ident("prompt"))
179                .and_then(|attr| {
180                    // Try to parse the attribute arguments
181                    attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
182                        .ok()
183                        .and_then(|metas| {
184                            metas.into_iter().find_map(|meta| match meta {
185                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
186                                    if let syn::Expr::Lit(expr_lit) = nv.value {
187                                        if let syn::Lit::Str(lit_str) = expr_lit.lit {
188                                            Some(lit_str.value())
189                                        } else {
190                                            None
191                                        }
192                                    } else {
193                                        None
194                                    }
195                                }
196                                _ => None,
197                            })
198                        })
199                });
200
201            let name = input.ident;
202            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
203
204            let expanded = if let Some(template_str) = template_attr {
205                // Use template-based approach if template is provided
206                quote! {
207                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
208                        fn to_prompt(&self) -> String {
209                            llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
210                                format!("Failed to render prompt: {}", e)
211                            })
212                        }
213                    }
214                }
215            } else {
216                // Use default key-value format if no template is provided
217                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
218                    &fields.named
219                } else {
220                    panic!(
221                        "Default prompt generation is only supported for structs with named fields."
222                    );
223                };
224
225                let field_prompts: Vec<_> = fields
226                    .iter()
227                    .filter_map(|f| {
228                        let field_name = f.ident.as_ref().unwrap();
229                        let attrs = parse_field_prompt_attrs(&f.attrs);
230
231                        // Skip if #[prompt(skip)] is present
232                        if attrs.skip {
233                            return None;
234                        }
235
236                        // Determine the key based on priority:
237                        // 1. #[prompt(rename = "new_name")]
238                        // 2. Doc comment
239                        // 3. Field name (fallback)
240                        let key = if let Some(rename) = attrs.rename {
241                            rename
242                        } else {
243                            let doc_comment = extract_doc_comments(&f.attrs);
244                            if !doc_comment.is_empty() {
245                                doc_comment
246                            } else {
247                                field_name.to_string()
248                            }
249                        };
250
251                        // Determine the value based on format_with attribute
252                        let value_expr = if let Some(format_with) = attrs.format_with {
253                            // Parse the function path string into a syn::Path
254                            let func_path: syn::Path =
255                                syn::parse_str(&format_with).unwrap_or_else(|_| {
256                                    panic!("Invalid function path: {}", format_with)
257                                });
258                            quote! { #func_path(&self.#field_name) }
259                        } else {
260                            quote! { self.#field_name.to_prompt() }
261                        };
262
263                        Some(quote! {
264                            format!("{}: {}", #key, #value_expr)
265                        })
266                    })
267                    .collect();
268
269                quote! {
270                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
271                        fn to_prompt(&self) -> String {
272                            let mut parts = Vec::new();
273                            #(
274                                parts.push(#field_prompts);
275                            )*
276                            parts.join("\n")
277                        }
278                    }
279                }
280            };
281
282            TokenStream::from(expanded)
283        }
284        Data::Union(_) => {
285            panic!("`#[derive(ToPrompt)]` is not supported for unions");
286        }
287    }
288}