1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if attr.path().is_ident("doc")
11 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12 && let syn::Expr::Lit(syn::ExprLit {
13 lit: syn::Lit::Str(lit_str),
14 ..
15 }) = &meta_name_value.value
16 {
17 return Some(lit_str.value());
18 }
19 None
20 })
21 .map(|s| s.trim().to_string())
22 .collect::<Vec<_>>()
23 .join(" ")
24}
25
/// Result of inspecting an enum variant's `#[prompt(...)]` attributes.
enum PromptAttribute {
    /// No usable `#[prompt]` attribute; the caller falls back to the variant's doc comments.
    None,
    /// `#[prompt("...")]`: the given text describes the variant.
    Description(String),
    /// `#[prompt(skip)]`: the variant is left out of the generated prompt.
    Skip,
}
32
33fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35 for attr in attrs {
36 if attr.path().is_ident("prompt") {
37 if let Ok(meta_list) = attr.meta.require_list() {
39 let tokens = &meta_list.tokens;
40 let tokens_str = tokens.to_string();
41 if tokens_str == "skip" {
42 return PromptAttribute::Skip;
43 }
44 }
45
46 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48 return PromptAttribute::Description(lit_str.value());
49 }
50 }
51 }
52 PromptAttribute::None
53}
54
/// Field-level options read from `#[prompt(...)]` by `#[derive(ToPrompt)]`.
#[derive(Default, Debug)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of the default (non-template) rendering.
    skip: bool,
    /// `rename = "..."`: label to use instead of the doc comment or field name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&self.field` to render the value.
    format_with: Option<String>,
    /// `image`: contribute the field's own `to_prompt_parts()` instead of a `key: value` line.
    image: bool,
}
63
64fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
66 let mut result = FieldPromptAttrs::default();
67
68 for attr in attrs {
69 if attr.path().is_ident("prompt") {
70 if let Ok(meta_list) = attr.meta.require_list() {
72 if let Ok(metas) =
74 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
75 {
76 for meta in metas {
77 match meta {
78 Meta::Path(path) if path.is_ident("skip") => {
79 result.skip = true;
80 }
81 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
82 if let syn::Expr::Lit(syn::ExprLit {
83 lit: syn::Lit::Str(lit_str),
84 ..
85 }) = nv.value
86 {
87 result.rename = Some(lit_str.value());
88 }
89 }
90 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
91 if let syn::Expr::Lit(syn::ExprLit {
92 lit: syn::Lit::Str(lit_str),
93 ..
94 }) = nv.value
95 {
96 result.format_with = Some(lit_str.value());
97 }
98 }
99 Meta::Path(path) if path.is_ident("image") => {
100 result.image = true;
101 }
102 _ => {}
103 }
104 }
105 } else if meta_list.tokens.to_string() == "skip" {
106 result.skip = true;
108 } else if meta_list.tokens.to_string() == "image" {
109 result.image = true;
111 }
112 }
113 }
114 }
115
116 result
117}
118
119#[proc_macro_derive(ToPrompt, attributes(prompt))]
120pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
121 let input = parse_macro_input!(input as DeriveInput);
122
123 match &input.data {
125 Data::Enum(data_enum) => {
126 let enum_name = &input.ident;
128 let enum_docs = extract_doc_comments(&input.attrs);
129
130 let mut prompt_lines = Vec::new();
131
132 if !enum_docs.is_empty() {
134 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
135 } else {
136 prompt_lines.push(format!("{}:", enum_name));
137 }
138 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
140
141 for variant in &data_enum.variants {
143 let variant_name = &variant.ident;
144
145 match parse_prompt_attribute(&variant.attrs) {
147 PromptAttribute::Skip => {
148 continue;
150 }
151 PromptAttribute::Description(desc) => {
152 prompt_lines.push(format!("- {}: {}", variant_name, desc));
154 }
155 PromptAttribute::None => {
156 let variant_docs = extract_doc_comments(&variant.attrs);
158 if !variant_docs.is_empty() {
159 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
160 } else {
161 prompt_lines.push(format!("- {}", variant_name));
162 }
163 }
164 }
165 }
166
167 let prompt_string = prompt_lines.join("\n");
168 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
169
170 let expanded = quote! {
171 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
172 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
173 vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
174 }
175
176 fn to_prompt(&self) -> String {
177 #prompt_string.to_string()
178 }
179 }
180 };
181
182 TokenStream::from(expanded)
183 }
184 Data::Struct(data_struct) => {
185 let template_attr = input
187 .attrs
188 .iter()
189 .find(|attr| attr.path().is_ident("prompt"))
190 .and_then(|attr| {
191 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
193 .ok()
194 .and_then(|metas| {
195 metas.into_iter().find_map(|meta| match meta {
196 Meta::NameValue(nv) if nv.path.is_ident("template") => {
197 if let syn::Expr::Lit(expr_lit) = nv.value {
198 if let syn::Lit::Str(lit_str) = expr_lit.lit {
199 Some(lit_str.value())
200 } else {
201 None
202 }
203 } else {
204 None
205 }
206 }
207 _ => None,
208 })
209 })
210 });
211
212 let name = input.ident;
213 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
214
215 let expanded = if let Some(template_str) = template_attr {
216 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
219 &fields.named
220 } else {
221 panic!(
222 "Template prompt generation is only supported for structs with named fields."
223 );
224 };
225
226 let mut image_field_parts = Vec::new();
227 for f in fields.iter() {
228 let field_name = f.ident.as_ref().unwrap();
229 let attrs = parse_field_prompt_attrs(&f.attrs);
230
231 if attrs.image {
232 image_field_parts.push(quote! {
234 parts.extend(self.#field_name.to_prompt_parts());
235 });
236 }
237 }
238
239 quote! {
240 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
241 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
242 let mut parts = Vec::new();
243
244 #(#image_field_parts)*
246
247 let text = llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
249 format!("Failed to render prompt: {}", e)
250 });
251 if !text.is_empty() {
252 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
253 }
254
255 parts
256 }
257
258 fn to_prompt(&self) -> String {
259 llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
260 format!("Failed to render prompt: {}", e)
261 })
262 }
263 }
264 }
265 } else {
266 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
269 &fields.named
270 } else {
271 panic!(
272 "Default prompt generation is only supported for structs with named fields."
273 );
274 };
275
276 let mut text_field_parts = Vec::new();
278 let mut image_field_parts = Vec::new();
279
280 for f in fields.iter() {
281 let field_name = f.ident.as_ref().unwrap();
282 let attrs = parse_field_prompt_attrs(&f.attrs);
283
284 if attrs.skip {
286 continue;
287 }
288
289 if attrs.image {
290 image_field_parts.push(quote! {
292 parts.extend(self.#field_name.to_prompt_parts());
293 });
294 } else {
295 let key = if let Some(rename) = attrs.rename {
301 rename
302 } else {
303 let doc_comment = extract_doc_comments(&f.attrs);
304 if !doc_comment.is_empty() {
305 doc_comment
306 } else {
307 field_name.to_string()
308 }
309 };
310
311 let value_expr = if let Some(format_with) = attrs.format_with {
313 let func_path: syn::Path =
315 syn::parse_str(&format_with).unwrap_or_else(|_| {
316 panic!("Invalid function path: {}", format_with)
317 });
318 quote! { #func_path(&self.#field_name) }
319 } else {
320 quote! { self.#field_name.to_prompt() }
321 };
322
323 text_field_parts.push(quote! {
324 text_parts.push(format!("{}: {}", #key, #value_expr));
325 });
326 }
327 }
328
329 quote! {
331 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
332 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
333 let mut parts = Vec::new();
334
335 #(#image_field_parts)*
337
338 let mut text_parts = Vec::new();
340 #(#text_field_parts)*
341
342 if !text_parts.is_empty() {
343 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
344 }
345
346 parts
347 }
348
349 fn to_prompt(&self) -> String {
350 let mut text_parts = Vec::new();
351 #(#text_field_parts)*
352 text_parts.join("\n")
353 }
354 }
355 }
356 };
357
358 TokenStream::from(expanded)
359 }
360 Data::Union(_) => {
361 panic!("`#[derive(ToPrompt)]` is not supported for unions");
362 }
363 }
364}
365
366#[derive(Debug, Clone)]
368struct TargetInfo {
369 name: String,
370 template: Option<String>,
371 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
372}
373
/// Options for one field within one `#[derive(ToPromptSet)]` target.
#[derive(Default, Clone, Debug)]
struct FieldTargetConfig {
    /// Exclude the field from this target (a bare `#[prompt_for(skip)]` sets this everywhere).
    skip: bool,
    /// `rename = "..."`: label used instead of the field name in the `key: value` line.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&self.field` to render the value.
    format_with: Option<String>,
    /// `image`: contribute the field's `to_prompt_parts()` instead of a text line.
    image: bool,
    /// Set for configs from `#[prompt_for(name = "...")]`; such fields appear only in the
    /// targets they name.
    include_only: bool,
}
383
384fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
386 let mut configs = Vec::new();
387
388 for attr in attrs {
389 if attr.path().is_ident("prompt_for")
390 && let Ok(meta_list) = attr.meta.require_list()
391 {
392 if meta_list.tokens.to_string() == "skip" {
394 let config = FieldTargetConfig {
396 skip: true,
397 ..Default::default()
398 };
399 configs.push(("*".to_string(), config));
400 } else if let Ok(metas) =
401 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
402 {
403 let mut target_name = None;
404 let mut config = FieldTargetConfig::default();
405
406 for meta in metas {
407 match meta {
408 Meta::NameValue(nv) if nv.path.is_ident("name") => {
409 if let syn::Expr::Lit(syn::ExprLit {
410 lit: syn::Lit::Str(lit_str),
411 ..
412 }) = nv.value
413 {
414 target_name = Some(lit_str.value());
415 }
416 }
417 Meta::Path(path) if path.is_ident("skip") => {
418 config.skip = true;
419 }
420 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
421 if let syn::Expr::Lit(syn::ExprLit {
422 lit: syn::Lit::Str(lit_str),
423 ..
424 }) = nv.value
425 {
426 config.rename = Some(lit_str.value());
427 }
428 }
429 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
430 if let syn::Expr::Lit(syn::ExprLit {
431 lit: syn::Lit::Str(lit_str),
432 ..
433 }) = nv.value
434 {
435 config.format_with = Some(lit_str.value());
436 }
437 }
438 Meta::Path(path) if path.is_ident("image") => {
439 config.image = true;
440 }
441 _ => {}
442 }
443 }
444
445 if let Some(name) = target_name {
446 config.include_only = true;
447 configs.push((name, config));
448 }
449 }
450 }
451 }
452
453 configs
454}
455
456fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
458 let mut targets = Vec::new();
459
460 for attr in attrs {
461 if attr.path().is_ident("prompt_for")
462 && let Ok(meta_list) = attr.meta.require_list()
463 && let Ok(metas) =
464 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
465 {
466 let mut target_name = None;
467 let mut template = None;
468
469 for meta in metas {
470 match meta {
471 Meta::NameValue(nv) if nv.path.is_ident("name") => {
472 if let syn::Expr::Lit(syn::ExprLit {
473 lit: syn::Lit::Str(lit_str),
474 ..
475 }) = nv.value
476 {
477 target_name = Some(lit_str.value());
478 }
479 }
480 Meta::NameValue(nv) if nv.path.is_ident("template") => {
481 if let syn::Expr::Lit(syn::ExprLit {
482 lit: syn::Lit::Str(lit_str),
483 ..
484 }) = nv.value
485 {
486 template = Some(lit_str.value());
487 }
488 }
489 _ => {}
490 }
491 }
492
493 if let Some(name) = target_name {
494 targets.push(TargetInfo {
495 name,
496 template,
497 field_configs: std::collections::HashMap::new(),
498 });
499 }
500 }
501 }
502
503 targets
504}
505
506#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
507pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
508 let input = parse_macro_input!(input as DeriveInput);
509
510 let data_struct = match &input.data {
512 Data::Struct(data) => data,
513 _ => {
514 return syn::Error::new(
515 input.ident.span(),
516 "`#[derive(ToPromptSet)]` is only supported for structs",
517 )
518 .to_compile_error()
519 .into();
520 }
521 };
522
523 let fields = match &data_struct.fields {
524 syn::Fields::Named(fields) => &fields.named,
525 _ => {
526 return syn::Error::new(
527 input.ident.span(),
528 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
529 )
530 .to_compile_error()
531 .into();
532 }
533 };
534
535 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
537
538 for field in fields.iter() {
540 let field_name = field.ident.as_ref().unwrap().to_string();
541 let field_configs = parse_prompt_for_attrs(&field.attrs);
542
543 for (target_name, config) in field_configs {
544 if target_name == "*" {
545 for target in &mut targets {
547 target
548 .field_configs
549 .entry(field_name.clone())
550 .or_insert_with(FieldTargetConfig::default)
551 .skip = config.skip;
552 }
553 } else {
554 let target_exists = targets.iter().any(|t| t.name == target_name);
556 if !target_exists {
557 targets.push(TargetInfo {
559 name: target_name.clone(),
560 template: None,
561 field_configs: std::collections::HashMap::new(),
562 });
563 }
564
565 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
566
567 target.field_configs.insert(field_name.clone(), config);
568 }
569 }
570 }
571
572 let mut match_arms = Vec::new();
574
575 for target in &targets {
576 let target_name = &target.name;
577
578 if let Some(template_str) = &target.template {
579 let mut image_parts = Vec::new();
581
582 for field in fields.iter() {
583 let field_name = field.ident.as_ref().unwrap();
584 let field_name_str = field_name.to_string();
585
586 if let Some(config) = target.field_configs.get(&field_name_str)
587 && config.image
588 {
589 image_parts.push(quote! {
590 parts.extend(self.#field_name.to_prompt_parts());
591 });
592 }
593 }
594
595 match_arms.push(quote! {
596 #target_name => {
597 let mut parts = Vec::new();
598
599 #(#image_parts)*
600
601 let text = llm_toolkit::prompt::render_prompt(#template_str, self)
602 .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
603 target: #target_name.to_string(),
604 source: e,
605 })?;
606
607 if !text.is_empty() {
608 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
609 }
610
611 Ok(parts)
612 }
613 });
614 } else {
615 let mut text_field_parts = Vec::new();
617 let mut image_field_parts = Vec::new();
618
619 for field in fields.iter() {
620 let field_name = field.ident.as_ref().unwrap();
621 let field_name_str = field_name.to_string();
622
623 let config = target.field_configs.get(&field_name_str);
625
626 if let Some(cfg) = config
628 && cfg.skip
629 {
630 continue;
631 }
632
633 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
637 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
638 .iter()
639 .any(|(name, _)| name != "*");
640
641 if has_any_target_specific_config && !is_explicitly_for_this_target {
642 continue;
643 }
644
645 if let Some(cfg) = config {
646 if cfg.image {
647 image_field_parts.push(quote! {
648 parts.extend(self.#field_name.to_prompt_parts());
649 });
650 } else {
651 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
652
653 let value_expr = if let Some(format_with) = &cfg.format_with {
654 match syn::parse_str::<syn::Path>(format_with) {
656 Ok(func_path) => quote! { #func_path(&self.#field_name) },
657 Err(_) => {
658 let error_msg = format!(
660 "Invalid function path in format_with: '{}'",
661 format_with
662 );
663 quote! {
664 compile_error!(#error_msg);
665 String::new()
666 }
667 }
668 }
669 } else {
670 quote! { self.#field_name.to_prompt() }
671 };
672
673 text_field_parts.push(quote! {
674 text_parts.push(format!("{}: {}", #key, #value_expr));
675 });
676 }
677 } else {
678 text_field_parts.push(quote! {
680 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
681 });
682 }
683 }
684
685 match_arms.push(quote! {
686 #target_name => {
687 let mut parts = Vec::new();
688
689 #(#image_field_parts)*
690
691 let mut text_parts = Vec::new();
692 #(#text_field_parts)*
693
694 if !text_parts.is_empty() {
695 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
696 }
697
698 Ok(parts)
699 }
700 });
701 }
702 }
703
704 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
706
707 match_arms.push(quote! {
709 _ => {
710 let available = vec![#(#target_names.to_string()),*];
711 Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
712 target: target.to_string(),
713 available,
714 })
715 }
716 });
717
718 let struct_name = &input.ident;
719 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
720
721 let expanded = quote! {
722 impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
723 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
724 match target {
725 #(#match_arms)*
726 }
727 }
728 }
729 };
730
731 TokenStream::from(expanded)
732}