// llm_toolkit_macros/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input, punctuated::Punctuated};

5fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if attr.path().is_ident("doc")
11 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12 && let syn::Expr::Lit(syn::ExprLit {
13 lit: syn::Lit::Str(lit_str),
14 ..
15 }) = &meta_name_value.value
16 {
17 return Some(lit_str.value());
18 }
19 None
20 })
21 .map(|s| s.trim().to_string())
22 .collect::<Vec<_>>()
23 .join(" ")
24}
25
/// Outcome of inspecting the `#[prompt(...)]` attributes on an enum variant.
enum PromptAttribute {
    /// `#[prompt(skip)]`: leave this variant out of the generated prompt.
    Skip,
    /// `#[prompt("...")]`: use the given text as the variant's description.
    Description(String),
    /// No recognised `#[prompt]` attribute was present.
    None,
}

33fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35 for attr in attrs {
36 if attr.path().is_ident("prompt") {
37 if let Ok(meta_list) = attr.meta.require_list() {
39 let tokens = &meta_list.tokens;
40 let tokens_str = tokens.to_string();
41 if tokens_str == "skip" {
42 return PromptAttribute::Skip;
43 }
44 }
45
46 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48 return PromptAttribute::Description(lit_str.value());
49 }
50 }
51 }
52 PromptAttribute::None
53}
54
/// Options gathered from the `#[prompt(...)]` attributes of one struct field.
#[derive(Default, Debug)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of the generated prompt.
    skip: bool,
    /// `rename = "..."`: label printed for the field instead of its doc comment or name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with a reference to the field to
    /// render its value.
    format_with: Option<String>,
}

63fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
65 let mut result = FieldPromptAttrs::default();
66
67 for attr in attrs {
68 if attr.path().is_ident("prompt") {
69 if let Ok(meta_list) = attr.meta.require_list() {
71 if let Ok(metas) =
73 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
74 {
75 for meta in metas {
76 match meta {
77 Meta::Path(path) if path.is_ident("skip") => {
78 result.skip = true;
79 }
80 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
81 if let syn::Expr::Lit(syn::ExprLit {
82 lit: syn::Lit::Str(lit_str),
83 ..
84 }) = nv.value
85 {
86 result.rename = Some(lit_str.value());
87 }
88 }
89 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
90 if let syn::Expr::Lit(syn::ExprLit {
91 lit: syn::Lit::Str(lit_str),
92 ..
93 }) = nv.value
94 {
95 result.format_with = Some(lit_str.value());
96 }
97 }
98 _ => {}
99 }
100 }
101 } else if meta_list.tokens.to_string() == "skip" {
102 result.skip = true;
104 }
105 }
106 }
107 }
108
109 result
110}
111
112#[proc_macro_derive(ToPrompt, attributes(prompt))]
113pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
114 let input = parse_macro_input!(input as DeriveInput);
115
116 match &input.data {
118 Data::Enum(data_enum) => {
119 let enum_name = &input.ident;
121 let enum_docs = extract_doc_comments(&input.attrs);
122
123 let mut prompt_lines = Vec::new();
124
125 if !enum_docs.is_empty() {
127 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
128 } else {
129 prompt_lines.push(format!("{}:", enum_name));
130 }
131 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
133
134 for variant in &data_enum.variants {
136 let variant_name = &variant.ident;
137
138 match parse_prompt_attribute(&variant.attrs) {
140 PromptAttribute::Skip => {
141 continue;
143 }
144 PromptAttribute::Description(desc) => {
145 prompt_lines.push(format!("- {}: {}", variant_name, desc));
147 }
148 PromptAttribute::None => {
149 let variant_docs = extract_doc_comments(&variant.attrs);
151 if !variant_docs.is_empty() {
152 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
153 } else {
154 prompt_lines.push(format!("- {}", variant_name));
155 }
156 }
157 }
158 }
159
160 let prompt_string = prompt_lines.join("\n");
161 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
162
163 let expanded = quote! {
164 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
165 fn to_prompt(&self) -> String {
166 #prompt_string.to_string()
167 }
168 }
169 };
170
171 TokenStream::from(expanded)
172 }
173 Data::Struct(data_struct) => {
174 let template_attr = input
176 .attrs
177 .iter()
178 .find(|attr| attr.path().is_ident("prompt"))
179 .and_then(|attr| {
180 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
182 .ok()
183 .and_then(|metas| {
184 metas.into_iter().find_map(|meta| match meta {
185 Meta::NameValue(nv) if nv.path.is_ident("template") => {
186 if let syn::Expr::Lit(expr_lit) = nv.value {
187 if let syn::Lit::Str(lit_str) = expr_lit.lit {
188 Some(lit_str.value())
189 } else {
190 None
191 }
192 } else {
193 None
194 }
195 }
196 _ => None,
197 })
198 })
199 });
200
201 let name = input.ident;
202 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
203
204 let expanded = if let Some(template_str) = template_attr {
205 quote! {
207 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
208 fn to_prompt(&self) -> String {
209 llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
210 format!("Failed to render prompt: {}", e)
211 })
212 }
213 }
214 }
215 } else {
216 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
218 &fields.named
219 } else {
220 panic!(
221 "Default prompt generation is only supported for structs with named fields."
222 );
223 };
224
225 let field_prompts: Vec<_> = fields
226 .iter()
227 .filter_map(|f| {
228 let field_name = f.ident.as_ref().unwrap();
229 let attrs = parse_field_prompt_attrs(&f.attrs);
230
231 if attrs.skip {
233 return None;
234 }
235
236 let key = if let Some(rename) = attrs.rename {
241 rename
242 } else {
243 let doc_comment = extract_doc_comments(&f.attrs);
244 if !doc_comment.is_empty() {
245 doc_comment
246 } else {
247 field_name.to_string()
248 }
249 };
250
251 let value_expr = if let Some(format_with) = attrs.format_with {
253 let func_path: syn::Path =
255 syn::parse_str(&format_with).unwrap_or_else(|_| {
256 panic!("Invalid function path: {}", format_with)
257 });
258 quote! { #func_path(&self.#field_name) }
259 } else {
260 quote! { self.#field_name.to_prompt() }
261 };
262
263 Some(quote! {
264 format!("{}: {}", #key, #value_expr)
265 })
266 })
267 .collect();
268
269 quote! {
270 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
271 fn to_prompt(&self) -> String {
272 let mut parts = Vec::new();
273 #(
274 parts.push(#field_prompts);
275 )*
276 parts.join("\n")
277 }
278 }
279 }
280 };
281
282 TokenStream::from(expanded)
283 }
284 Data::Union(_) => {
285 panic!("`#[derive(ToPrompt)]` is not supported for unions");
286 }
287 }
288}