llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5/// Extract doc comments from attributes
6fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if attr.path().is_ident("doc") {
11                if let syn::Meta::NameValue(meta_name_value) = &attr.meta {
12                    if let syn::Expr::Lit(syn::ExprLit {
13                        lit: syn::Lit::Str(lit_str),
14                        ..
15                    }) = &meta_name_value.value
16                    {
17                        return Some(lit_str.value());
18                    }
19                }
20            }
21            None
22        })
23        .map(|s| s.trim().to_string())
24        .collect::<Vec<_>>()
25        .join(" ")
26}
27
/// What a variant's `#[prompt(...)]` attribute asks for.
enum PromptAttribute {
    /// `#[prompt(skip)]` — leave this variant out of the generated prompt.
    Skip,
    /// `#[prompt("...")]` — use the given text as the variant's description.
    Description(String),
    /// No recognised `#[prompt(...)]` attribute is present.
    None,
}
34
35/// Parse #[prompt(...)] attribute on enum variant
36fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
37    for attr in attrs {
38        if attr.path().is_ident("prompt") {
39            // Check for #[prompt(skip)]
40            if let Ok(meta_list) = attr.meta.require_list() {
41                let tokens = &meta_list.tokens;
42                let tokens_str = tokens.to_string();
43                if tokens_str == "skip" {
44                    return PromptAttribute::Skip;
45                }
46            }
47
48            // Check for #[prompt("description")]
49            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
50                return PromptAttribute::Description(lit_str.value());
51            }
52        }
53    }
54    PromptAttribute::None
55}
56
57#[proc_macro_derive(ToPrompt, attributes(prompt))]
58pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
59    let input = parse_macro_input!(input as DeriveInput);
60
61    // Check if this is a struct or enum
62    match &input.data {
63        Data::Enum(data_enum) => {
64            // For enums, generate prompt from doc comments
65            let enum_name = &input.ident;
66            let enum_docs = extract_doc_comments(&input.attrs);
67
68            let mut prompt_lines = Vec::new();
69
70            // Add enum description
71            if !enum_docs.is_empty() {
72                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
73            } else {
74                prompt_lines.push(format!("{}:", enum_name));
75            }
76            prompt_lines.push(String::new()); // Empty line
77            prompt_lines.push("Possible values:".to_string());
78
79            // Add each variant with its documentation based on priority
80            for variant in &data_enum.variants {
81                let variant_name = &variant.ident;
82
83                // Apply fallback logic with priority
84                match parse_prompt_attribute(&variant.attrs) {
85                    PromptAttribute::Skip => {
86                        // Skip this variant completely
87                        continue;
88                    }
89                    PromptAttribute::Description(desc) => {
90                        // Use custom description from #[prompt("...")]
91                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
92                    }
93                    PromptAttribute::None => {
94                        // Fall back to doc comment or just variant name
95                        let variant_docs = extract_doc_comments(&variant.attrs);
96                        if !variant_docs.is_empty() {
97                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
98                        } else {
99                            prompt_lines.push(format!("- {}", variant_name));
100                        }
101                    }
102                }
103            }
104
105            let prompt_string = prompt_lines.join("\n");
106            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
107
108            let expanded = quote! {
109                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
110                    fn to_prompt(&self) -> String {
111                        #prompt_string.to_string()
112                    }
113                }
114            };
115
116            TokenStream::from(expanded)
117        }
118        Data::Struct(_) => {
119            // For structs, use the existing template-based approach
120            let attr = input
121                .attrs
122                .iter()
123                .find(|attr| attr.path().is_ident("prompt"))
124                .expect("`#[derive(ToPrompt)]` on structs requires a `#[prompt(...)]` attribute.");
125
126            // `syn::Attribute::parse_args_with` を使って属性をパースする
127            let name_value = attr
128                .parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
129                .expect("Failed to parse `prompt` attribute arguments")
130                .into_iter()
131                .find_map(|meta| match meta {
132                    Meta::NameValue(nv) if nv.path.is_ident("template") => Some(nv),
133                    _ => None,
134                })
135                .expect("`#[prompt(...)]` must contain `template = \"...\"`");
136
137            let template_str = if let syn::Expr::Lit(expr_lit) = name_value.value {
138                if let syn::Lit::Str(lit_str) = expr_lit.lit {
139                    lit_str.value()
140                } else {
141                    panic!("'template' attribute value must be a string literal.");
142                }
143            } else {
144                panic!("'template' attribute must have a literal value.");
145            };
146
147            let name = input.ident;
148            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
149
150            let expanded = quote! {
151                impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
152                    fn to_prompt(&self) -> String {
153                        llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
154                            format!("Failed to render prompt: {}", e)
155                        })
156                    }
157                }
158            };
159
160            TokenStream::from(expanded)
161        }
162        Data::Union(_) => {
163            panic!("`#[derive(ToPrompt)]` is not supported for unions");
164        }
165    }
166}