1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if attr.path().is_ident("doc")
11 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12 && let syn::Expr::Lit(syn::ExprLit {
13 lit: syn::Lit::Str(lit_str),
14 ..
15 }) = &meta_name_value.value
16 {
17 return Some(lit_str.value());
18 }
19 None
20 })
21 .map(|s| s.trim().to_string())
22 .collect::<Vec<_>>()
23 .join(" ")
24}
25
26enum PromptAttribute {
28 Skip,
29 Description(String),
30 None,
31}
32
33fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35 for attr in attrs {
36 if attr.path().is_ident("prompt") {
37 if let Ok(meta_list) = attr.meta.require_list() {
39 let tokens = &meta_list.tokens;
40 let tokens_str = tokens.to_string();
41 if tokens_str == "skip" {
42 return PromptAttribute::Skip;
43 }
44 }
45
46 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48 return PromptAttribute::Description(lit_str.value());
49 }
50 }
51 }
52 PromptAttribute::None
53}
54
55#[derive(Debug, Default)]
57struct FieldPromptAttrs {
58 skip: bool,
59 rename: Option<String>,
60 format_with: Option<String>,
61 image: bool,
62}
63
64fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
66 let mut result = FieldPromptAttrs::default();
67
68 for attr in attrs {
69 if attr.path().is_ident("prompt") {
70 if let Ok(meta_list) = attr.meta.require_list() {
72 if let Ok(metas) =
74 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
75 {
76 for meta in metas {
77 match meta {
78 Meta::Path(path) if path.is_ident("skip") => {
79 result.skip = true;
80 }
81 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
82 if let syn::Expr::Lit(syn::ExprLit {
83 lit: syn::Lit::Str(lit_str),
84 ..
85 }) = nv.value
86 {
87 result.rename = Some(lit_str.value());
88 }
89 }
90 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
91 if let syn::Expr::Lit(syn::ExprLit {
92 lit: syn::Lit::Str(lit_str),
93 ..
94 }) = nv.value
95 {
96 result.format_with = Some(lit_str.value());
97 }
98 }
99 Meta::Path(path) if path.is_ident("image") => {
100 result.image = true;
101 }
102 _ => {}
103 }
104 }
105 } else if meta_list.tokens.to_string() == "skip" {
106 result.skip = true;
108 } else if meta_list.tokens.to_string() == "image" {
109 result.image = true;
111 }
112 }
113 }
114 }
115
116 result
117}
118
119#[proc_macro_derive(ToPrompt, attributes(prompt))]
162pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
163 let input = parse_macro_input!(input as DeriveInput);
164
165 match &input.data {
167 Data::Enum(data_enum) => {
168 let enum_name = &input.ident;
170 let enum_docs = extract_doc_comments(&input.attrs);
171
172 let mut prompt_lines = Vec::new();
173
174 if !enum_docs.is_empty() {
176 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
177 } else {
178 prompt_lines.push(format!("{}:", enum_name));
179 }
180 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
182
183 for variant in &data_enum.variants {
185 let variant_name = &variant.ident;
186
187 match parse_prompt_attribute(&variant.attrs) {
189 PromptAttribute::Skip => {
190 continue;
192 }
193 PromptAttribute::Description(desc) => {
194 prompt_lines.push(format!("- {}: {}", variant_name, desc));
196 }
197 PromptAttribute::None => {
198 let variant_docs = extract_doc_comments(&variant.attrs);
200 if !variant_docs.is_empty() {
201 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
202 } else {
203 prompt_lines.push(format!("- {}", variant_name));
204 }
205 }
206 }
207 }
208
209 let prompt_string = prompt_lines.join("\n");
210 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
211
212 let expanded = quote! {
213 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
214 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
215 vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
216 }
217
218 fn to_prompt(&self) -> String {
219 #prompt_string.to_string()
220 }
221 }
222 };
223
224 TokenStream::from(expanded)
225 }
226 Data::Struct(data_struct) => {
227 let template_attr = input
229 .attrs
230 .iter()
231 .find(|attr| attr.path().is_ident("prompt"))
232 .and_then(|attr| {
233 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
235 .ok()
236 .and_then(|metas| {
237 metas.into_iter().find_map(|meta| match meta {
238 Meta::NameValue(nv) if nv.path.is_ident("template") => {
239 if let syn::Expr::Lit(expr_lit) = nv.value {
240 if let syn::Lit::Str(lit_str) = expr_lit.lit {
241 Some(lit_str.value())
242 } else {
243 None
244 }
245 } else {
246 None
247 }
248 }
249 _ => None,
250 })
251 })
252 });
253
254 let name = input.ident;
255 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
256
257 let expanded = if let Some(template_str) = template_attr {
258 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
261 &fields.named
262 } else {
263 panic!(
264 "Template prompt generation is only supported for structs with named fields."
265 );
266 };
267
268 let mut image_field_parts = Vec::new();
269 for f in fields.iter() {
270 let field_name = f.ident.as_ref().unwrap();
271 let attrs = parse_field_prompt_attrs(&f.attrs);
272
273 if attrs.image {
274 image_field_parts.push(quote! {
276 parts.extend(self.#field_name.to_prompt_parts());
277 });
278 }
279 }
280
281 quote! {
282 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
283 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
284 let mut parts = Vec::new();
285
286 #(#image_field_parts)*
288
289 let text = llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
291 format!("Failed to render prompt: {}", e)
292 });
293 if !text.is_empty() {
294 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
295 }
296
297 parts
298 }
299
300 fn to_prompt(&self) -> String {
301 llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
302 format!("Failed to render prompt: {}", e)
303 })
304 }
305 }
306 }
307 } else {
308 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
311 &fields.named
312 } else {
313 panic!(
314 "Default prompt generation is only supported for structs with named fields."
315 );
316 };
317
318 let mut text_field_parts = Vec::new();
320 let mut image_field_parts = Vec::new();
321
322 for f in fields.iter() {
323 let field_name = f.ident.as_ref().unwrap();
324 let attrs = parse_field_prompt_attrs(&f.attrs);
325
326 if attrs.skip {
328 continue;
329 }
330
331 if attrs.image {
332 image_field_parts.push(quote! {
334 parts.extend(self.#field_name.to_prompt_parts());
335 });
336 } else {
337 let key = if let Some(rename) = attrs.rename {
343 rename
344 } else {
345 let doc_comment = extract_doc_comments(&f.attrs);
346 if !doc_comment.is_empty() {
347 doc_comment
348 } else {
349 field_name.to_string()
350 }
351 };
352
353 let value_expr = if let Some(format_with) = attrs.format_with {
355 let func_path: syn::Path =
357 syn::parse_str(&format_with).unwrap_or_else(|_| {
358 panic!("Invalid function path: {}", format_with)
359 });
360 quote! { #func_path(&self.#field_name) }
361 } else {
362 quote! { self.#field_name.to_prompt() }
363 };
364
365 text_field_parts.push(quote! {
366 text_parts.push(format!("{}: {}", #key, #value_expr));
367 });
368 }
369 }
370
371 quote! {
373 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
374 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
375 let mut parts = Vec::new();
376
377 #(#image_field_parts)*
379
380 let mut text_parts = Vec::new();
382 #(#text_field_parts)*
383
384 if !text_parts.is_empty() {
385 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
386 }
387
388 parts
389 }
390
391 fn to_prompt(&self) -> String {
392 let mut text_parts = Vec::new();
393 #(#text_field_parts)*
394 text_parts.join("\n")
395 }
396 }
397 }
398 };
399
400 TokenStream::from(expanded)
401 }
402 Data::Union(_) => {
403 panic!("`#[derive(ToPrompt)]` is not supported for unions");
404 }
405 }
406}
407
/// One named prompt target of a `#[derive(ToPromptSet)]` struct.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name; matched against the `target` argument of `to_prompt_parts_for`.
    name: String,
    /// Template from struct-level `#[prompt_for(name = ..., template = ...)]`, if any.
    template: Option<String>,
    /// Per-field options for this target, keyed by the field's identifier.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}

/// Options that a field-level `#[prompt_for(...)]` attribute sets for one target.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Leave the field out of this target's default layout.
    skip: bool,
    /// Label to use instead of the field name.
    rename: Option<String>,
    /// Path of a function used to render the field's value.
    format_with: Option<String>,
    /// Emit the field's own prompt parts rather than a `key: value` line.
    image: bool,
    /// Set when the config came from an attribute naming this target explicitly
    /// (`name = "..."`), i.e. the field opted into this target.
    include_only: bool,
}
425
426fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
428 let mut configs = Vec::new();
429
430 for attr in attrs {
431 if attr.path().is_ident("prompt_for")
432 && let Ok(meta_list) = attr.meta.require_list()
433 {
434 if meta_list.tokens.to_string() == "skip" {
436 let config = FieldTargetConfig {
438 skip: true,
439 ..Default::default()
440 };
441 configs.push(("*".to_string(), config));
442 } else if let Ok(metas) =
443 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
444 {
445 let mut target_name = None;
446 let mut config = FieldTargetConfig::default();
447
448 for meta in metas {
449 match meta {
450 Meta::NameValue(nv) if nv.path.is_ident("name") => {
451 if let syn::Expr::Lit(syn::ExprLit {
452 lit: syn::Lit::Str(lit_str),
453 ..
454 }) = nv.value
455 {
456 target_name = Some(lit_str.value());
457 }
458 }
459 Meta::Path(path) if path.is_ident("skip") => {
460 config.skip = true;
461 }
462 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
463 if let syn::Expr::Lit(syn::ExprLit {
464 lit: syn::Lit::Str(lit_str),
465 ..
466 }) = nv.value
467 {
468 config.rename = Some(lit_str.value());
469 }
470 }
471 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
472 if let syn::Expr::Lit(syn::ExprLit {
473 lit: syn::Lit::Str(lit_str),
474 ..
475 }) = nv.value
476 {
477 config.format_with = Some(lit_str.value());
478 }
479 }
480 Meta::Path(path) if path.is_ident("image") => {
481 config.image = true;
482 }
483 _ => {}
484 }
485 }
486
487 if let Some(name) = target_name {
488 config.include_only = true;
489 configs.push((name, config));
490 }
491 }
492 }
493 }
494
495 configs
496}
497
498fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
500 let mut targets = Vec::new();
501
502 for attr in attrs {
503 if attr.path().is_ident("prompt_for")
504 && let Ok(meta_list) = attr.meta.require_list()
505 && let Ok(metas) =
506 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
507 {
508 let mut target_name = None;
509 let mut template = None;
510
511 for meta in metas {
512 match meta {
513 Meta::NameValue(nv) if nv.path.is_ident("name") => {
514 if let syn::Expr::Lit(syn::ExprLit {
515 lit: syn::Lit::Str(lit_str),
516 ..
517 }) = nv.value
518 {
519 target_name = Some(lit_str.value());
520 }
521 }
522 Meta::NameValue(nv) if nv.path.is_ident("template") => {
523 if let syn::Expr::Lit(syn::ExprLit {
524 lit: syn::Lit::Str(lit_str),
525 ..
526 }) = nv.value
527 {
528 template = Some(lit_str.value());
529 }
530 }
531 _ => {}
532 }
533 }
534
535 if let Some(name) = target_name {
536 targets.push(TargetInfo {
537 name,
538 template,
539 field_configs: std::collections::HashMap::new(),
540 });
541 }
542 }
543 }
544
545 targets
546}
547
548#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
549pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
550 let input = parse_macro_input!(input as DeriveInput);
551
552 let data_struct = match &input.data {
554 Data::Struct(data) => data,
555 _ => {
556 return syn::Error::new(
557 input.ident.span(),
558 "`#[derive(ToPromptSet)]` is only supported for structs",
559 )
560 .to_compile_error()
561 .into();
562 }
563 };
564
565 let fields = match &data_struct.fields {
566 syn::Fields::Named(fields) => &fields.named,
567 _ => {
568 return syn::Error::new(
569 input.ident.span(),
570 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
571 )
572 .to_compile_error()
573 .into();
574 }
575 };
576
577 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
579
580 for field in fields.iter() {
582 let field_name = field.ident.as_ref().unwrap().to_string();
583 let field_configs = parse_prompt_for_attrs(&field.attrs);
584
585 for (target_name, config) in field_configs {
586 if target_name == "*" {
587 for target in &mut targets {
589 target
590 .field_configs
591 .entry(field_name.clone())
592 .or_insert_with(FieldTargetConfig::default)
593 .skip = config.skip;
594 }
595 } else {
596 let target_exists = targets.iter().any(|t| t.name == target_name);
598 if !target_exists {
599 targets.push(TargetInfo {
601 name: target_name.clone(),
602 template: None,
603 field_configs: std::collections::HashMap::new(),
604 });
605 }
606
607 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
608
609 target.field_configs.insert(field_name.clone(), config);
610 }
611 }
612 }
613
614 let mut match_arms = Vec::new();
616
617 for target in &targets {
618 let target_name = &target.name;
619
620 if let Some(template_str) = &target.template {
621 let mut image_parts = Vec::new();
623
624 for field in fields.iter() {
625 let field_name = field.ident.as_ref().unwrap();
626 let field_name_str = field_name.to_string();
627
628 if let Some(config) = target.field_configs.get(&field_name_str)
629 && config.image
630 {
631 image_parts.push(quote! {
632 parts.extend(self.#field_name.to_prompt_parts());
633 });
634 }
635 }
636
637 match_arms.push(quote! {
638 #target_name => {
639 let mut parts = Vec::new();
640
641 #(#image_parts)*
642
643 let text = llm_toolkit::prompt::render_prompt(#template_str, self)
644 .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
645 target: #target_name.to_string(),
646 source: e,
647 })?;
648
649 if !text.is_empty() {
650 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
651 }
652
653 Ok(parts)
654 }
655 });
656 } else {
657 let mut text_field_parts = Vec::new();
659 let mut image_field_parts = Vec::new();
660
661 for field in fields.iter() {
662 let field_name = field.ident.as_ref().unwrap();
663 let field_name_str = field_name.to_string();
664
665 let config = target.field_configs.get(&field_name_str);
667
668 if let Some(cfg) = config
670 && cfg.skip
671 {
672 continue;
673 }
674
675 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
679 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
680 .iter()
681 .any(|(name, _)| name != "*");
682
683 if has_any_target_specific_config && !is_explicitly_for_this_target {
684 continue;
685 }
686
687 if let Some(cfg) = config {
688 if cfg.image {
689 image_field_parts.push(quote! {
690 parts.extend(self.#field_name.to_prompt_parts());
691 });
692 } else {
693 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
694
695 let value_expr = if let Some(format_with) = &cfg.format_with {
696 match syn::parse_str::<syn::Path>(format_with) {
698 Ok(func_path) => quote! { #func_path(&self.#field_name) },
699 Err(_) => {
700 let error_msg = format!(
702 "Invalid function path in format_with: '{}'",
703 format_with
704 );
705 quote! {
706 compile_error!(#error_msg);
707 String::new()
708 }
709 }
710 }
711 } else {
712 quote! { self.#field_name.to_prompt() }
713 };
714
715 text_field_parts.push(quote! {
716 text_parts.push(format!("{}: {}", #key, #value_expr));
717 });
718 }
719 } else {
720 text_field_parts.push(quote! {
722 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
723 });
724 }
725 }
726
727 match_arms.push(quote! {
728 #target_name => {
729 let mut parts = Vec::new();
730
731 #(#image_field_parts)*
732
733 let mut text_parts = Vec::new();
734 #(#text_field_parts)*
735
736 if !text_parts.is_empty() {
737 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
738 }
739
740 Ok(parts)
741 }
742 });
743 }
744 }
745
746 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
748
749 match_arms.push(quote! {
751 _ => {
752 let available = vec![#(#target_names.to_string()),*];
753 Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
754 target: target.to_string(),
755 available,
756 })
757 }
758 });
759
760 let struct_name = &input.ident;
761 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
762
763 let expanded = quote! {
764 impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
765 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
766 match target {
767 #(#match_arms)*
768 }
769 }
770 }
771 };
772
773 TokenStream::from(expanded)
774}