llm_toolkit_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if attr.path().is_ident("doc")
11 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12 && let syn::Expr::Lit(syn::ExprLit {
13 lit: syn::Lit::Str(lit_str),
14 ..
15 }) = &meta_name_value.value
16 {
17 return Some(lit_str.value());
18 }
19 None
20 })
21 .map(|s| s.trim().to_string())
22 .collect::<Vec<_>>()
23 .join(" ")
24}
25
/// Result of inspecting an item's `#[prompt(...)]` attributes.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptAttribute {
    /// `#[prompt(skip)]` was found: the item should be left out of the prompt.
    Skip,
    /// `#[prompt("...")]` was found: the string is an explicit description.
    Description(String),
    /// No `#[prompt(...)]` attribute yielded either of the forms above.
    None,
}
32
33fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35 for attr in attrs {
36 if attr.path().is_ident("prompt") {
37 if let Ok(meta_list) = attr.meta.require_list() {
39 let tokens = &meta_list.tokens;
40 let tokens_str = tokens.to_string();
41 if tokens_str == "skip" {
42 return PromptAttribute::Skip;
43 }
44 }
45
46 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48 return PromptAttribute::Description(lit_str.value());
49 }
50 }
51 }
52 PromptAttribute::None
53}
54
55#[proc_macro_derive(ToPrompt, attributes(prompt))]
56pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
57 let input = parse_macro_input!(input as DeriveInput);
58
59 match &input.data {
61 Data::Enum(data_enum) => {
62 let enum_name = &input.ident;
64 let enum_docs = extract_doc_comments(&input.attrs);
65
66 let mut prompt_lines = Vec::new();
67
68 if !enum_docs.is_empty() {
70 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
71 } else {
72 prompt_lines.push(format!("{}:", enum_name));
73 }
74 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
76
77 for variant in &data_enum.variants {
79 let variant_name = &variant.ident;
80
81 match parse_prompt_attribute(&variant.attrs) {
83 PromptAttribute::Skip => {
84 continue;
86 }
87 PromptAttribute::Description(desc) => {
88 prompt_lines.push(format!("- {}: {}", variant_name, desc));
90 }
91 PromptAttribute::None => {
92 let variant_docs = extract_doc_comments(&variant.attrs);
94 if !variant_docs.is_empty() {
95 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
96 } else {
97 prompt_lines.push(format!("- {}", variant_name));
98 }
99 }
100 }
101 }
102
103 let prompt_string = prompt_lines.join("\n");
104 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
105
106 let expanded = quote! {
107 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
108 fn to_prompt(&self) -> String {
109 #prompt_string.to_string()
110 }
111 }
112 };
113
114 TokenStream::from(expanded)
115 }
116 Data::Struct(_) => {
117 let attr = input
119 .attrs
120 .iter()
121 .find(|attr| attr.path().is_ident("prompt"))
122 .expect("`#[derive(ToPrompt)]` on structs requires a `#[prompt(...)]` attribute.");
123
124 let name_value = attr
126 .parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
127 .expect("Failed to parse `prompt` attribute arguments")
128 .into_iter()
129 .find_map(|meta| match meta {
130 Meta::NameValue(nv) if nv.path.is_ident("template") => Some(nv),
131 _ => None,
132 })
133 .expect("`#[prompt(...)]` must contain `template = \"...\"`");
134
135 let template_str = if let syn::Expr::Lit(expr_lit) = name_value.value {
136 if let syn::Lit::Str(lit_str) = expr_lit.lit {
137 lit_str.value()
138 } else {
139 panic!("'template' attribute value must be a string literal.");
140 }
141 } else {
142 panic!("'template' attribute must have a literal value.");
143 };
144
145 let name = input.ident;
146 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
147
148 let expanded = quote! {
149 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
150 fn to_prompt(&self) -> String {
151 llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
152 format!("Failed to render prompt: {}", e)
153 })
154 }
155 }
156 };
157
158 TokenStream::from(expanded)
159 }
160 Data::Union(_) => {
161 panic!("`#[derive(ToPrompt)]` is not supported for unions");
162 }
163 }
164}