llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5/// Extract doc comments from attributes
6fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if attr.path().is_ident("doc")
11                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12                && let syn::Expr::Lit(syn::ExprLit {
13                    lit: syn::Lit::Str(lit_str),
14                    ..
15                }) = &meta_name_value.value
16            {
17                return Some(lit_str.value());
18            }
19            None
20        })
21        .map(|s| s.trim().to_string())
22        .collect::<Vec<_>>()
23        .join(" ")
24}
25
/// Outcome of inspecting an enum variant's `#[prompt(...)]` attributes.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is omitted from the generated prompt.
    Skip,
    /// `#[prompt("...")]`: the variant is listed with this custom description.
    Description(String),
    /// No recognised `#[prompt(...)]` form was present on the variant.
    None,
}
32
33/// Parse #[prompt(...)] attribute on enum variant
34fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35    for attr in attrs {
36        if attr.path().is_ident("prompt") {
37            // Check for #[prompt(skip)]
38            if let Ok(meta_list) = attr.meta.require_list() {
39                let tokens = &meta_list.tokens;
40                let tokens_str = tokens.to_string();
41                if tokens_str == "skip" {
42                    return PromptAttribute::Skip;
43                }
44            }
45
46            // Check for #[prompt("description")]
47            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48                return PromptAttribute::Description(lit_str.value());
49            }
50        }
51    }
52    PromptAttribute::None
53}
54
55#[proc_macro_derive(ToPrompt, attributes(prompt))]
56pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
57    let input = parse_macro_input!(input as DeriveInput);
58
59    // Check if this is a struct or enum
60    match &input.data {
61        Data::Enum(data_enum) => {
62            // For enums, generate prompt from doc comments
63            let enum_name = &input.ident;
64            let enum_docs = extract_doc_comments(&input.attrs);
65
66            let mut prompt_lines = Vec::new();
67
68            // Add enum description
69            if !enum_docs.is_empty() {
70                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
71            } else {
72                prompt_lines.push(format!("{}:", enum_name));
73            }
74            prompt_lines.push(String::new()); // Empty line
75            prompt_lines.push("Possible values:".to_string());
76
77            // Add each variant with its documentation based on priority
78            for variant in &data_enum.variants {
79                let variant_name = &variant.ident;
80
81                // Apply fallback logic with priority
82                match parse_prompt_attribute(&variant.attrs) {
83                    PromptAttribute::Skip => {
84                        // Skip this variant completely
85                        continue;
86                    }
87                    PromptAttribute::Description(desc) => {
88                        // Use custom description from #[prompt("...")]
89                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
90                    }
91                    PromptAttribute::None => {
92                        // Fall back to doc comment or just variant name
93                        let variant_docs = extract_doc_comments(&variant.attrs);
94                        if !variant_docs.is_empty() {
95                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
96                        } else {
97                            prompt_lines.push(format!("- {}", variant_name));
98                        }
99                    }
100                }
101            }
102
103            let prompt_string = prompt_lines.join("\n");
104            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
105
106            let expanded = quote! {
107                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
108                    fn to_prompt(&self) -> String {
109                        #prompt_string.to_string()
110                    }
111                }
112            };
113
114            TokenStream::from(expanded)
115        }
116        Data::Struct(_) => {
117            // For structs, use the existing template-based approach
118            let attr = input
119                .attrs
120                .iter()
121                .find(|attr| attr.path().is_ident("prompt"))
122                .expect("`#[derive(ToPrompt)]` on structs requires a `#[prompt(...)]` attribute.");
123
124            // `syn::Attribute::parse_args_with` を使って属性をパースする
125            let name_value = attr
126                .parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
127                .expect("Failed to parse `prompt` attribute arguments")
128                .into_iter()
129                .find_map(|meta| match meta {
130                    Meta::NameValue(nv) if nv.path.is_ident("template") => Some(nv),
131                    _ => None,
132                })
133                .expect("`#[prompt(...)]` must contain `template = \"...\"`");
134
135            let template_str = if let syn::Expr::Lit(expr_lit) = name_value.value {
136                if let syn::Lit::Str(lit_str) = expr_lit.lit {
137                    lit_str.value()
138                } else {
139                    panic!("'template' attribute value must be a string literal.");
140                }
141            } else {
142                panic!("'template' attribute must have a literal value.");
143            };
144
145            let name = input.ident;
146            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
147
148            let expanded = quote! {
149                impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
150                    fn to_prompt(&self) -> String {
151                        llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
152                            format!("Failed to render prompt: {}", e)
153                        })
154                    }
155                }
156            };
157
158            TokenStream::from(expanded)
159        }
160        Data::Union(_) => {
161            panic!("`#[derive(ToPrompt)]` is not supported for unions");
162        }
163    }
164}