1use proc_macro::TokenStream;
2use quote::quote;
3use regex::Regex;
4use syn::{
5 Data, DeriveInput, Meta, Token,
6 parse::{Parse, ParseStream},
7 parse_macro_input,
8 punctuated::Punctuated,
9};
10
11fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
14 let mut placeholders = Vec::new();
15 let mut seen_fields = std::collections::HashSet::new();
16
17 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
19 for cap in mode_pattern.captures_iter(template) {
20 let field_name = cap[1].to_string();
21 let mode = cap[2].to_string();
22 placeholders.push((field_name.clone(), Some(mode)));
23 seen_fields.insert(field_name);
24 }
25
26 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
28 for cap in standard_pattern.captures_iter(template) {
29 let field_name = cap[1].to_string();
30 if !seen_fields.contains(&field_name) {
32 placeholders.push((field_name, None));
33 }
34 }
35
36 placeholders
37}
38
39fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
41 attrs
42 .iter()
43 .filter_map(|attr| {
44 if attr.path().is_ident("doc")
45 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
46 && let syn::Expr::Lit(syn::ExprLit {
47 lit: syn::Lit::Str(lit_str),
48 ..
49 }) = &meta_name_value.value
50 {
51 return Some(lit_str.value());
52 }
53 None
54 })
55 .map(|s| s.trim().to_string())
56 .collect::<Vec<_>>()
57 .join(" ")
58}
59
60fn generate_example_only_parts(
62 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
63 has_default: bool,
64) -> proc_macro2::TokenStream {
65 let mut field_values = Vec::new();
66
67 for field in fields.iter() {
68 let field_name = field.ident.as_ref().unwrap();
69 let field_name_str = field_name.to_string();
70 let attrs = parse_field_prompt_attrs(&field.attrs);
71
72 if attrs.skip {
74 continue;
75 }
76
77 if let Some(example) = attrs.example {
79 field_values.push(quote! {
81 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
82 });
83 } else if has_default {
84 field_values.push(quote! {
86 let default_value = serde_json::to_value(&default_instance.#field_name)
87 .unwrap_or(serde_json::Value::Null);
88 json_obj.insert(#field_name_str.to_string(), default_value);
89 });
90 } else {
91 field_values.push(quote! {
93 let value = serde_json::to_value(&self.#field_name)
94 .unwrap_or(serde_json::Value::Null);
95 json_obj.insert(#field_name_str.to_string(), value);
96 });
97 }
98 }
99
100 if has_default {
101 quote! {
102 {
103 let default_instance = Self::default();
104 let mut json_obj = serde_json::Map::new();
105 #(#field_values)*
106 let json_value = serde_json::Value::Object(json_obj);
107 let json_str = serde_json::to_string_pretty(&json_value)
108 .unwrap_or_else(|_| "{}".to_string());
109 vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
110 }
111 }
112 } else {
113 quote! {
114 {
115 let mut json_obj = serde_json::Map::new();
116 #(#field_values)*
117 let json_value = serde_json::Value::Object(json_obj);
118 let json_str = serde_json::to_string_pretty(&json_value)
119 .unwrap_or_else(|_| "{}".to_string());
120 vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
121 }
122 }
123 }
124}
125
126fn generate_schema_only_parts(
128 struct_name: &str,
129 struct_docs: &str,
130 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
131) -> proc_macro2::TokenStream {
132 let mut schema_lines = vec![];
133
134 if !struct_docs.is_empty() {
136 schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
137 } else {
138 schema_lines.push(format!("### Schema for `{}`", struct_name));
139 }
140
141 schema_lines.push("{".to_string());
142
143 for (i, field) in fields.iter().enumerate() {
145 let field_name = field.ident.as_ref().unwrap();
146 let attrs = parse_field_prompt_attrs(&field.attrs);
147
148 if attrs.skip {
150 continue;
151 }
152
153 let field_docs = extract_doc_comments(&field.attrs);
155
156 let type_str = format_type_for_schema(&field.ty);
158
159 let mut field_line = format!(" \"{}\": \"{}\"", field_name, type_str);
161
162 if !field_docs.is_empty() {
164 field_line.push_str(&format!(", // {}", field_docs));
165 }
166
167 let remaining_fields = fields
169 .iter()
170 .skip(i + 1)
171 .filter(|f| {
172 let attrs = parse_field_prompt_attrs(&f.attrs);
173 !attrs.skip
174 })
175 .count();
176
177 if remaining_fields > 0 {
178 field_line.push(',');
179 }
180
181 schema_lines.push(field_line);
182 }
183
184 schema_lines.push("}".to_string());
185
186 let schema_str = schema_lines.join("\n");
187
188 quote! {
189 vec![llm_toolkit::prompt::PromptPart::Text(#schema_str.to_string())]
190 }
191}
192
193fn format_type_for_schema(ty: &syn::Type) -> String {
195 match ty {
197 syn::Type::Path(type_path) => {
198 let path = &type_path.path;
199 if let Some(last_segment) = path.segments.last() {
200 let type_name = last_segment.ident.to_string();
201
202 if type_name == "Option"
204 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
205 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
206 {
207 return format!("{} | null", format_type_for_schema(inner_type));
208 }
209
210 match type_name.as_str() {
212 "String" | "str" => "string".to_string(),
213 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
214 | "u64" | "u128" | "usize" => "number".to_string(),
215 "f32" | "f64" => "number".to_string(),
216 "bool" => "boolean".to_string(),
217 "Vec" => {
218 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
219 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
220 {
221 return format!("{}[]", format_type_for_schema(inner_type));
222 }
223 "array".to_string()
224 }
225 _ => type_name.to_lowercase(),
226 }
227 } else {
228 "unknown".to_string()
229 }
230 }
231 _ => "unknown".to_string(),
232 }
233}
234
/// Result of inspecting the `#[prompt(...)]` attribute on an enum variant.
enum PromptAttribute {
    /// `#[prompt(skip)]` — the variant is left out of the generated prompt.
    Skip,
    /// `#[prompt("...")]` — the given text is used as the variant description.
    Description(String),
    /// No usable `#[prompt(...)]` attribute was found.
    None,
}
241
242fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
244 for attr in attrs {
245 if attr.path().is_ident("prompt") {
246 if let Ok(meta_list) = attr.meta.require_list() {
248 let tokens = &meta_list.tokens;
249 let tokens_str = tokens.to_string();
250 if tokens_str == "skip" {
251 return PromptAttribute::Skip;
252 }
253 }
254
255 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
257 return PromptAttribute::Description(lit_str.value());
258 }
259 }
260 }
261 PromptAttribute::None
262}
263
/// Options collected from the `#[prompt(...)]` attributes of a struct field.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by `#[prompt(skip)]`.
    skip: bool,
    /// Value of `#[prompt(rename = "...")]`, if present.
    rename: Option<String>,
    /// Value of `#[prompt(format_with = "...")]`, if present.
    format_with: Option<String>,
    /// Set by `#[prompt(image)]`.
    image: bool,
    /// Value of `#[prompt(example = "...")]`, if present.
    example: Option<String>,
}
273
274fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
276 let mut result = FieldPromptAttrs::default();
277
278 for attr in attrs {
279 if attr.path().is_ident("prompt") {
280 if let Ok(meta_list) = attr.meta.require_list() {
282 if let Ok(metas) =
284 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
285 {
286 for meta in metas {
287 match meta {
288 Meta::Path(path) if path.is_ident("skip") => {
289 result.skip = true;
290 }
291 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
292 if let syn::Expr::Lit(syn::ExprLit {
293 lit: syn::Lit::Str(lit_str),
294 ..
295 }) = nv.value
296 {
297 result.rename = Some(lit_str.value());
298 }
299 }
300 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
301 if let syn::Expr::Lit(syn::ExprLit {
302 lit: syn::Lit::Str(lit_str),
303 ..
304 }) = nv.value
305 {
306 result.format_with = Some(lit_str.value());
307 }
308 }
309 Meta::Path(path) if path.is_ident("image") => {
310 result.image = true;
311 }
312 Meta::NameValue(nv) if nv.path.is_ident("example") => {
313 if let syn::Expr::Lit(syn::ExprLit {
314 lit: syn::Lit::Str(lit_str),
315 ..
316 }) = nv.value
317 {
318 result.example = Some(lit_str.value());
319 }
320 }
321 _ => {}
322 }
323 }
324 } else if meta_list.tokens.to_string() == "skip" {
325 result.skip = true;
327 } else if meta_list.tokens.to_string() == "image" {
328 result.image = true;
330 }
331 }
332 }
333 }
334
335 result
336}
337
338#[proc_macro_derive(ToPrompt, attributes(prompt))]
381pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
382 let input = parse_macro_input!(input as DeriveInput);
383
384 match &input.data {
386 Data::Enum(data_enum) => {
387 let enum_name = &input.ident;
389 let enum_docs = extract_doc_comments(&input.attrs);
390
391 let mut prompt_lines = Vec::new();
392
393 if !enum_docs.is_empty() {
395 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
396 } else {
397 prompt_lines.push(format!("{}:", enum_name));
398 }
399 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
401
402 for variant in &data_enum.variants {
404 let variant_name = &variant.ident;
405
406 match parse_prompt_attribute(&variant.attrs) {
408 PromptAttribute::Skip => {
409 continue;
411 }
412 PromptAttribute::Description(desc) => {
413 prompt_lines.push(format!("- {}: {}", variant_name, desc));
415 }
416 PromptAttribute::None => {
417 let variant_docs = extract_doc_comments(&variant.attrs);
419 if !variant_docs.is_empty() {
420 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
421 } else {
422 prompt_lines.push(format!("- {}", variant_name));
423 }
424 }
425 }
426 }
427
428 let prompt_string = prompt_lines.join("\n");
429 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
430
431 let expanded = quote! {
432 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
433 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
434 vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
435 }
436
437 fn to_prompt(&self) -> String {
438 #prompt_string.to_string()
439 }
440 }
441 };
442
443 TokenStream::from(expanded)
444 }
445 Data::Struct(data_struct) => {
446 let mut template_attr = None;
448 let mut template_file_attr = None;
449 let mut mode_attr = None;
450 let mut validate_attr = false;
451
452 for attr in &input.attrs {
453 if attr.path().is_ident("prompt") {
454 if let Ok(metas) =
456 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
457 {
458 for meta in metas {
459 match meta {
460 Meta::NameValue(nv) if nv.path.is_ident("template") => {
461 if let syn::Expr::Lit(expr_lit) = nv.value
462 && let syn::Lit::Str(lit_str) = expr_lit.lit
463 {
464 template_attr = Some(lit_str.value());
465 }
466 }
467 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
468 if let syn::Expr::Lit(expr_lit) = nv.value
469 && let syn::Lit::Str(lit_str) = expr_lit.lit
470 {
471 template_file_attr = Some(lit_str.value());
472 }
473 }
474 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
475 if let syn::Expr::Lit(expr_lit) = nv.value
476 && let syn::Lit::Str(lit_str) = expr_lit.lit
477 {
478 mode_attr = Some(lit_str.value());
479 }
480 }
481 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
482 if let syn::Expr::Lit(expr_lit) = nv.value
483 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
484 {
485 validate_attr = lit_bool.value();
486 }
487 }
488 _ => {}
489 }
490 }
491 }
492 }
493 }
494
495 if template_attr.is_some() && template_file_attr.is_some() {
497 return syn::Error::new(
498 input.ident.span(),
499 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
500 ).to_compile_error().into();
501 }
502
503 let template_str = if let Some(file_path) = template_file_attr {
505 let mut full_path = None;
509
510 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
512 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
514
515 if !is_trybuild {
516 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
518 if candidate.exists() {
519 full_path = Some(candidate);
520 }
521 } else {
522 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
528 let workspace_root = &manifest_dir[..target_pos];
529 let original_macros_dir = std::path::Path::new(workspace_root)
531 .join("crates")
532 .join("llm-toolkit-macros");
533
534 let candidate = original_macros_dir.join(&file_path);
535 if candidate.exists() {
536 full_path = Some(candidate);
537 }
538 }
539 }
540 }
541
542 if full_path.is_none() {
544 let candidate = std::path::Path::new(&file_path).to_path_buf();
545 if candidate.exists() {
546 full_path = Some(candidate);
547 }
548 }
549
550 if full_path.is_none()
553 && let Ok(current_dir) = std::env::current_dir()
554 {
555 let mut search_dir = current_dir.as_path();
556 for _ in 0..10 {
558 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
560 if macros_dir.exists() {
561 let candidate = macros_dir.join(&file_path);
562 if candidate.exists() {
563 full_path = Some(candidate);
564 break;
565 }
566 }
567 let candidate = search_dir.join(&file_path);
569 if candidate.exists() {
570 full_path = Some(candidate);
571 break;
572 }
573 if let Some(parent) = search_dir.parent() {
574 search_dir = parent;
575 } else {
576 break;
577 }
578 }
579 }
580
581 let final_path =
583 full_path.unwrap_or_else(|| std::path::Path::new(&file_path).to_path_buf());
584
585 match std::fs::read_to_string(&final_path) {
587 Ok(content) => Some(content),
588 Err(e) => {
589 return syn::Error::new(
590 input.ident.span(),
591 format!(
592 "Failed to read template file '{}': {}",
593 final_path.display(),
594 e
595 ),
596 )
597 .to_compile_error()
598 .into();
599 }
600 }
601 } else {
602 template_attr
603 };
604
605 if validate_attr && let Some(template) = &template_str {
607 let mut env = minijinja::Environment::new();
609 if let Err(e) = env.add_template("validation", template) {
610 let warning_msg =
612 format!("Template validation warning: Invalid Jinja syntax - {}", e);
613 let warning_ident = syn::Ident::new(
614 "TEMPLATE_VALIDATION_WARNING",
615 proc_macro2::Span::call_site(),
616 );
617 let _warning_tokens = quote! {
618 #[deprecated(note = #warning_msg)]
619 const #warning_ident: () = ();
620 let _ = #warning_ident;
621 };
622 eprintln!("cargo:warning={}", warning_msg);
624 }
625
626 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
628 &fields.named
629 } else {
630 panic!("Template validation is only supported for structs with named fields.");
631 };
632
633 let field_names: std::collections::HashSet<String> = fields
634 .iter()
635 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
636 .collect();
637
638 let placeholders = parse_template_placeholders_with_mode(template);
640
641 for (placeholder_name, _mode) in &placeholders {
642 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
643 let warning_msg = format!(
644 "Template validation warning: Variable '{}' used in template but not found in struct fields",
645 placeholder_name
646 );
647 eprintln!("cargo:warning={}", warning_msg);
648 }
649 }
650 }
651
652 let name = input.ident;
653 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
654
655 let struct_docs = extract_doc_comments(&input.attrs);
657
658 let is_mode_based =
660 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
661
662 let expanded = if is_mode_based || mode_attr.is_some() {
663 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
665 &fields.named
666 } else {
667 panic!(
668 "Mode-based prompt generation is only supported for structs with named fields."
669 );
670 };
671
672 let struct_name_str = name.to_string();
673
674 let has_default = input.attrs.iter().any(|attr| {
676 if attr.path().is_ident("derive")
677 && let Ok(meta_list) = attr.meta.require_list()
678 {
679 let tokens_str = meta_list.tokens.to_string();
680 tokens_str.contains("Default")
681 } else {
682 false
683 }
684 });
685
686 let schema_parts =
688 generate_schema_only_parts(&struct_name_str, &struct_docs, fields);
689
690 let example_parts = generate_example_only_parts(fields, has_default);
692
693 quote! {
694 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
695 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<llm_toolkit::prompt::PromptPart> {
696 match mode {
697 "schema_only" => #schema_parts,
698 "example_only" => #example_parts,
699 "full" | _ => {
700 let mut parts = Vec::new();
702
703 let schema_parts = #schema_parts;
705 parts.extend(schema_parts);
706
707 parts.push(llm_toolkit::prompt::PromptPart::Text("\n### Example".to_string()));
709 parts.push(llm_toolkit::prompt::PromptPart::Text(
710 format!("Here is an example of a valid `{}` object:", #struct_name_str)
711 ));
712
713 let example_parts = #example_parts;
715 parts.extend(example_parts);
716
717 parts
718 }
719 }
720 }
721
722 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
723 self.to_prompt_parts_with_mode("full")
724 }
725
726 fn to_prompt(&self) -> String {
727 self.to_prompt_parts()
728 .into_iter()
729 .filter_map(|part| match part {
730 llm_toolkit::prompt::PromptPart::Text(text) => Some(text),
731 _ => None,
732 })
733 .collect::<Vec<_>>()
734 .join("\n")
735 }
736 }
737 }
738 } else if let Some(template) = template_str {
739 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
742 &fields.named
743 } else {
744 panic!(
745 "Template prompt generation is only supported for structs with named fields."
746 );
747 };
748
749 let placeholders = parse_template_placeholders_with_mode(&template);
751 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
753 mode.is_some()
754 && fields
755 .iter()
756 .any(|f| f.ident.as_ref().unwrap() == field_name)
757 });
758
759 let mut image_field_parts = Vec::new();
760 for f in fields.iter() {
761 let field_name = f.ident.as_ref().unwrap();
762 let attrs = parse_field_prompt_attrs(&f.attrs);
763
764 if attrs.image {
765 image_field_parts.push(quote! {
767 parts.extend(self.#field_name.to_prompt_parts());
768 });
769 }
770 }
771
772 if has_mode_syntax {
774 let mut context_fields = Vec::new();
776 let mut modified_template = template.clone();
777
778 for (field_name, mode_opt) in &placeholders {
780 if let Some(mode) = mode_opt {
781 let unique_key = format!("{}__{}", field_name, mode);
783
784 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
786 let replacement = format!("{{{{ {} }}}}", unique_key);
787 modified_template = modified_template.replace(&pattern, &replacement);
788
789 let field_ident =
791 syn::Ident::new(field_name, proc_macro2::Span::call_site());
792
793 context_fields.push(quote! {
795 context.insert(
796 #unique_key.to_string(),
797 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
798 );
799 });
800 }
801 }
802
803 for field in fields.iter() {
805 let field_name = field.ident.as_ref().unwrap();
806 let field_name_str = field_name.to_string();
807
808 let has_mode_entry = placeholders
810 .iter()
811 .any(|(name, mode)| name == &field_name_str && mode.is_some());
812
813 if !has_mode_entry {
814 let is_primitive = match &field.ty {
817 syn::Type::Path(type_path) => {
818 if let Some(segment) = type_path.path.segments.last() {
819 let type_name = segment.ident.to_string();
820 matches!(
821 type_name.as_str(),
822 "String"
823 | "str"
824 | "i8"
825 | "i16"
826 | "i32"
827 | "i64"
828 | "i128"
829 | "isize"
830 | "u8"
831 | "u16"
832 | "u32"
833 | "u64"
834 | "u128"
835 | "usize"
836 | "f32"
837 | "f64"
838 | "bool"
839 | "char"
840 )
841 } else {
842 false
843 }
844 }
845 _ => false,
846 };
847
848 if is_primitive {
849 context_fields.push(quote! {
850 context.insert(
851 #field_name_str.to_string(),
852 minijinja::Value::from_serialize(&self.#field_name)
853 );
854 });
855 } else {
856 context_fields.push(quote! {
858 context.insert(
859 #field_name_str.to_string(),
860 minijinja::Value::from(self.#field_name.to_prompt())
861 );
862 });
863 }
864 }
865 }
866
867 quote! {
868 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
869 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
870 let mut parts = Vec::new();
871
872 #(#image_field_parts)*
874
875 let text = {
877 let mut env = minijinja::Environment::new();
878 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
879 panic!("Failed to parse template: {}", e)
880 });
881
882 let tmpl = env.get_template("prompt").unwrap();
883
884 let mut context = std::collections::HashMap::new();
885 #(#context_fields)*
886
887 tmpl.render(context).unwrap_or_else(|e| {
888 format!("Failed to render prompt: {}", e)
889 })
890 };
891
892 if !text.is_empty() {
893 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
894 }
895
896 parts
897 }
898
899 fn to_prompt(&self) -> String {
900 let mut env = minijinja::Environment::new();
902 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
903 panic!("Failed to parse template: {}", e)
904 });
905
906 let tmpl = env.get_template("prompt").unwrap();
907
908 let mut context = std::collections::HashMap::new();
909 #(#context_fields)*
910
911 tmpl.render(context).unwrap_or_else(|e| {
912 format!("Failed to render prompt: {}", e)
913 })
914 }
915 }
916 }
917 } else {
918 quote! {
920 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
921 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
922 let mut parts = Vec::new();
923
924 #(#image_field_parts)*
926
927 let text = llm_toolkit::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
929 format!("Failed to render prompt: {}", e)
930 });
931 if !text.is_empty() {
932 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
933 }
934
935 parts
936 }
937
938 fn to_prompt(&self) -> String {
939 llm_toolkit::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
940 format!("Failed to render prompt: {}", e)
941 })
942 }
943 }
944 }
945 }
946 } else {
947 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
950 &fields.named
951 } else {
952 panic!(
953 "Default prompt generation is only supported for structs with named fields."
954 );
955 };
956
957 let mut text_field_parts = Vec::new();
959 let mut image_field_parts = Vec::new();
960
961 for f in fields.iter() {
962 let field_name = f.ident.as_ref().unwrap();
963 let attrs = parse_field_prompt_attrs(&f.attrs);
964
965 if attrs.skip {
967 continue;
968 }
969
970 if attrs.image {
971 image_field_parts.push(quote! {
973 parts.extend(self.#field_name.to_prompt_parts());
974 });
975 } else {
976 let key = if let Some(rename) = attrs.rename {
982 rename
983 } else {
984 let doc_comment = extract_doc_comments(&f.attrs);
985 if !doc_comment.is_empty() {
986 doc_comment
987 } else {
988 field_name.to_string()
989 }
990 };
991
992 let value_expr = if let Some(format_with) = attrs.format_with {
994 let func_path: syn::Path =
996 syn::parse_str(&format_with).unwrap_or_else(|_| {
997 panic!("Invalid function path: {}", format_with)
998 });
999 quote! { #func_path(&self.#field_name) }
1000 } else {
1001 quote! { self.#field_name.to_prompt() }
1002 };
1003
1004 text_field_parts.push(quote! {
1005 text_parts.push(format!("{}: {}", #key, #value_expr));
1006 });
1007 }
1008 }
1009
1010 quote! {
1012 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
1013 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
1014 let mut parts = Vec::new();
1015
1016 #(#image_field_parts)*
1018
1019 let mut text_parts = Vec::new();
1021 #(#text_field_parts)*
1022
1023 if !text_parts.is_empty() {
1024 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1025 }
1026
1027 parts
1028 }
1029
1030 fn to_prompt(&self) -> String {
1031 let mut text_parts = Vec::new();
1032 #(#text_field_parts)*
1033 text_parts.join("\n")
1034 }
1035 }
1036 }
1037 };
1038
1039 TokenStream::from(expanded)
1040 }
1041 Data::Union(_) => {
1042 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1043 }
1044 }
1045}
1046
1047#[derive(Debug, Clone)]
1049struct TargetInfo {
1050 name: String,
1051 template: Option<String>,
1052 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1053}
1054
/// Settings a field carries for a single prompt target.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// The field is left out of the target.
    skip: bool,
    /// Key to use instead of the field name.
    rename: Option<String>,
    /// Path of a function used to format the field value.
    format_with: Option<String>,
    /// The field contributes its own prompt parts instead of a text line.
    image: bool,
    /// Set when the config came from an attribute naming a specific target.
    include_only: bool,
}
1064
1065fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1067 let mut configs = Vec::new();
1068
1069 for attr in attrs {
1070 if attr.path().is_ident("prompt_for")
1071 && let Ok(meta_list) = attr.meta.require_list()
1072 {
1073 if meta_list.tokens.to_string() == "skip" {
1075 let config = FieldTargetConfig {
1077 skip: true,
1078 ..Default::default()
1079 };
1080 configs.push(("*".to_string(), config));
1081 } else if let Ok(metas) =
1082 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1083 {
1084 let mut target_name = None;
1085 let mut config = FieldTargetConfig::default();
1086
1087 for meta in metas {
1088 match meta {
1089 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1090 if let syn::Expr::Lit(syn::ExprLit {
1091 lit: syn::Lit::Str(lit_str),
1092 ..
1093 }) = nv.value
1094 {
1095 target_name = Some(lit_str.value());
1096 }
1097 }
1098 Meta::Path(path) if path.is_ident("skip") => {
1099 config.skip = true;
1100 }
1101 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1102 if let syn::Expr::Lit(syn::ExprLit {
1103 lit: syn::Lit::Str(lit_str),
1104 ..
1105 }) = nv.value
1106 {
1107 config.rename = Some(lit_str.value());
1108 }
1109 }
1110 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1111 if let syn::Expr::Lit(syn::ExprLit {
1112 lit: syn::Lit::Str(lit_str),
1113 ..
1114 }) = nv.value
1115 {
1116 config.format_with = Some(lit_str.value());
1117 }
1118 }
1119 Meta::Path(path) if path.is_ident("image") => {
1120 config.image = true;
1121 }
1122 _ => {}
1123 }
1124 }
1125
1126 if let Some(name) = target_name {
1127 config.include_only = true;
1128 configs.push((name, config));
1129 }
1130 }
1131 }
1132 }
1133
1134 configs
1135}
1136
1137fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1139 let mut targets = Vec::new();
1140
1141 for attr in attrs {
1142 if attr.path().is_ident("prompt_for")
1143 && let Ok(meta_list) = attr.meta.require_list()
1144 && let Ok(metas) =
1145 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1146 {
1147 let mut target_name = None;
1148 let mut template = None;
1149
1150 for meta in metas {
1151 match meta {
1152 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1153 if let syn::Expr::Lit(syn::ExprLit {
1154 lit: syn::Lit::Str(lit_str),
1155 ..
1156 }) = nv.value
1157 {
1158 target_name = Some(lit_str.value());
1159 }
1160 }
1161 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1162 if let syn::Expr::Lit(syn::ExprLit {
1163 lit: syn::Lit::Str(lit_str),
1164 ..
1165 }) = nv.value
1166 {
1167 template = Some(lit_str.value());
1168 }
1169 }
1170 _ => {}
1171 }
1172 }
1173
1174 if let Some(name) = target_name {
1175 targets.push(TargetInfo {
1176 name,
1177 template,
1178 field_configs: std::collections::HashMap::new(),
1179 });
1180 }
1181 }
1182 }
1183
1184 targets
1185}
1186
1187#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1188pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1189 let input = parse_macro_input!(input as DeriveInput);
1190
1191 let data_struct = match &input.data {
1193 Data::Struct(data) => data,
1194 _ => {
1195 return syn::Error::new(
1196 input.ident.span(),
1197 "`#[derive(ToPromptSet)]` is only supported for structs",
1198 )
1199 .to_compile_error()
1200 .into();
1201 }
1202 };
1203
1204 let fields = match &data_struct.fields {
1205 syn::Fields::Named(fields) => &fields.named,
1206 _ => {
1207 return syn::Error::new(
1208 input.ident.span(),
1209 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1210 )
1211 .to_compile_error()
1212 .into();
1213 }
1214 };
1215
1216 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1218
1219 for field in fields.iter() {
1221 let field_name = field.ident.as_ref().unwrap().to_string();
1222 let field_configs = parse_prompt_for_attrs(&field.attrs);
1223
1224 for (target_name, config) in field_configs {
1225 if target_name == "*" {
1226 for target in &mut targets {
1228 target
1229 .field_configs
1230 .entry(field_name.clone())
1231 .or_insert_with(FieldTargetConfig::default)
1232 .skip = config.skip;
1233 }
1234 } else {
1235 let target_exists = targets.iter().any(|t| t.name == target_name);
1237 if !target_exists {
1238 targets.push(TargetInfo {
1240 name: target_name.clone(),
1241 template: None,
1242 field_configs: std::collections::HashMap::new(),
1243 });
1244 }
1245
1246 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1247
1248 target.field_configs.insert(field_name.clone(), config);
1249 }
1250 }
1251 }
1252
1253 let mut match_arms = Vec::new();
1255
1256 for target in &targets {
1257 let target_name = &target.name;
1258
1259 if let Some(template_str) = &target.template {
1260 let mut image_parts = Vec::new();
1262
1263 for field in fields.iter() {
1264 let field_name = field.ident.as_ref().unwrap();
1265 let field_name_str = field_name.to_string();
1266
1267 if let Some(config) = target.field_configs.get(&field_name_str)
1268 && config.image
1269 {
1270 image_parts.push(quote! {
1271 parts.extend(self.#field_name.to_prompt_parts());
1272 });
1273 }
1274 }
1275
1276 match_arms.push(quote! {
1277 #target_name => {
1278 let mut parts = Vec::new();
1279
1280 #(#image_parts)*
1281
1282 let text = llm_toolkit::prompt::render_prompt(#template_str, self)
1283 .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
1284 target: #target_name.to_string(),
1285 source: e,
1286 })?;
1287
1288 if !text.is_empty() {
1289 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
1290 }
1291
1292 Ok(parts)
1293 }
1294 });
1295 } else {
1296 let mut text_field_parts = Vec::new();
1298 let mut image_field_parts = Vec::new();
1299
1300 for field in fields.iter() {
1301 let field_name = field.ident.as_ref().unwrap();
1302 let field_name_str = field_name.to_string();
1303
1304 let config = target.field_configs.get(&field_name_str);
1306
1307 if let Some(cfg) = config
1309 && cfg.skip
1310 {
1311 continue;
1312 }
1313
1314 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1318 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1319 .iter()
1320 .any(|(name, _)| name != "*");
1321
1322 if has_any_target_specific_config && !is_explicitly_for_this_target {
1323 continue;
1324 }
1325
1326 if let Some(cfg) = config {
1327 if cfg.image {
1328 image_field_parts.push(quote! {
1329 parts.extend(self.#field_name.to_prompt_parts());
1330 });
1331 } else {
1332 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1333
1334 let value_expr = if let Some(format_with) = &cfg.format_with {
1335 match syn::parse_str::<syn::Path>(format_with) {
1337 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1338 Err(_) => {
1339 let error_msg = format!(
1341 "Invalid function path in format_with: '{}'",
1342 format_with
1343 );
1344 quote! {
1345 compile_error!(#error_msg);
1346 String::new()
1347 }
1348 }
1349 }
1350 } else {
1351 quote! { self.#field_name.to_prompt() }
1352 };
1353
1354 text_field_parts.push(quote! {
1355 text_parts.push(format!("{}: {}", #key, #value_expr));
1356 });
1357 }
1358 } else {
1359 text_field_parts.push(quote! {
1361 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1362 });
1363 }
1364 }
1365
1366 match_arms.push(quote! {
1367 #target_name => {
1368 let mut parts = Vec::new();
1369
1370 #(#image_field_parts)*
1371
1372 let mut text_parts = Vec::new();
1373 #(#text_field_parts)*
1374
1375 if !text_parts.is_empty() {
1376 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1377 }
1378
1379 Ok(parts)
1380 }
1381 });
1382 }
1383 }
1384
1385 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1387
1388 match_arms.push(quote! {
1390 _ => {
1391 let available = vec![#(#target_names.to_string()),*];
1392 Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
1393 target: target.to_string(),
1394 available,
1395 })
1396 }
1397 });
1398
1399 let struct_name = &input.ident;
1400 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1401
1402 let expanded = quote! {
1403 impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1404 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
1405 match target {
1406 #(#match_arms)*
1407 }
1408 }
1409 }
1410 };
1411
1412 TokenStream::from(expanded)
1413}
1414
/// Comma-separated list of types accepted as the input of `examples_section!`.
struct TypeList {
    /// The parsed types, in source order, together with their separating commas.
    types: Punctuated<syn::Type, Token![,]>,
}
1419
1420impl Parse for TypeList {
1421 fn parse(input: ParseStream) -> syn::Result<Self> {
1422 Ok(TypeList {
1423 types: Punctuated::parse_terminated(input)?,
1424 })
1425 }
1426}
1427
1428#[proc_macro]
1452pub fn examples_section(input: TokenStream) -> TokenStream {
1453 let input = parse_macro_input!(input as TypeList);
1454
1455 let mut type_sections = Vec::new();
1457
1458 for ty in input.types.iter() {
1459 let type_name_str = quote!(#ty).to_string();
1461
1462 type_sections.push(quote! {
1464 {
1465 let type_name = #type_name_str;
1466 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1467 format!("---\n#### `{}`\n{}", type_name, json_example)
1468 }
1469 });
1470 }
1471
1472 let expanded = quote! {
1474 {
1475 let mut sections = Vec::new();
1476 sections.push("---".to_string());
1477 sections.push("### Examples".to_string());
1478 sections.push("".to_string());
1479 sections.push("Here are examples of the data structures you should use.".to_string());
1480 sections.push("".to_string());
1481
1482 #(sections.push(#type_sections);)*
1483
1484 sections.push("---".to_string());
1485
1486 sections.join("\n")
1487 }
1488 };
1489
1490 TokenStream::from(expanded)
1491}
1492
1493fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1495 for attr in attrs {
1496 if attr.path().is_ident("prompt_for")
1497 && let Ok(meta_list) = attr.meta.require_list()
1498 && let Ok(metas) =
1499 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1500 {
1501 let mut target_type = None;
1502 let mut template = None;
1503
1504 for meta in metas {
1505 match meta {
1506 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1507 if let syn::Expr::Lit(syn::ExprLit {
1508 lit: syn::Lit::Str(lit_str),
1509 ..
1510 }) = nv.value
1511 {
1512 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1514 }
1515 }
1516 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1517 if let syn::Expr::Lit(syn::ExprLit {
1518 lit: syn::Lit::Str(lit_str),
1519 ..
1520 }) = nv.value
1521 {
1522 template = Some(lit_str.value());
1523 }
1524 }
1525 _ => {}
1526 }
1527 }
1528
1529 if let (Some(target), Some(tmpl)) = (target_type, template) {
1530 return (target, tmpl);
1531 }
1532 }
1533 }
1534
1535 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1536}
1537
1538#[proc_macro_attribute]
1572pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1573 let input = parse_macro_input!(item as DeriveInput);
1574
1575 let enum_data = match &input.data {
1577 Data::Enum(data) => data,
1578 _ => {
1579 return syn::Error::new(
1580 input.ident.span(),
1581 "`#[define_intent]` can only be applied to enums",
1582 )
1583 .to_compile_error()
1584 .into();
1585 }
1586 };
1587
1588 let mut prompt_template = None;
1590 let mut extractor_tag = None;
1591 let mut mode = None;
1592
1593 for attr in &input.attrs {
1594 if attr.path().is_ident("intent")
1595 && let Ok(metas) =
1596 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1597 {
1598 for meta in metas {
1599 match meta {
1600 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1601 if let syn::Expr::Lit(syn::ExprLit {
1602 lit: syn::Lit::Str(lit_str),
1603 ..
1604 }) = nv.value
1605 {
1606 prompt_template = Some(lit_str.value());
1607 }
1608 }
1609 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1610 if let syn::Expr::Lit(syn::ExprLit {
1611 lit: syn::Lit::Str(lit_str),
1612 ..
1613 }) = nv.value
1614 {
1615 extractor_tag = Some(lit_str.value());
1616 }
1617 }
1618 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1619 if let syn::Expr::Lit(syn::ExprLit {
1620 lit: syn::Lit::Str(lit_str),
1621 ..
1622 }) = nv.value
1623 {
1624 mode = Some(lit_str.value());
1625 }
1626 }
1627 _ => {}
1628 }
1629 }
1630 }
1631 }
1632
1633 let mode = mode.unwrap_or_else(|| "single".to_string());
1635
1636 if mode != "single" && mode != "multi_tag" {
1638 return syn::Error::new(
1639 input.ident.span(),
1640 "`mode` must be either \"single\" or \"multi_tag\"",
1641 )
1642 .to_compile_error()
1643 .into();
1644 }
1645
1646 let prompt_template = match prompt_template {
1648 Some(p) => p,
1649 None => {
1650 return syn::Error::new(
1651 input.ident.span(),
1652 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1653 )
1654 .to_compile_error()
1655 .into();
1656 }
1657 };
1658
1659 if mode == "multi_tag" {
1661 let enum_name = &input.ident;
1662 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1663 return generate_multi_tag_output(
1664 &input,
1665 enum_name,
1666 enum_data,
1667 prompt_template,
1668 actions_doc,
1669 );
1670 }
1671
1672 let extractor_tag = match extractor_tag {
1674 Some(t) => t,
1675 None => {
1676 return syn::Error::new(
1677 input.ident.span(),
1678 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1679 )
1680 .to_compile_error()
1681 .into();
1682 }
1683 };
1684
1685 let enum_name = &input.ident;
1687 let enum_docs = extract_doc_comments(&input.attrs);
1688
1689 let mut intents_doc_lines = Vec::new();
1690
1691 if !enum_docs.is_empty() {
1693 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1694 } else {
1695 intents_doc_lines.push(format!("{}:", enum_name));
1696 }
1697 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
1699
1700 for variant in &enum_data.variants {
1702 let variant_name = &variant.ident;
1703 let variant_docs = extract_doc_comments(&variant.attrs);
1704
1705 if !variant_docs.is_empty() {
1706 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1707 } else {
1708 intents_doc_lines.push(format!("- {}", variant_name));
1709 }
1710 }
1711
1712 let intents_doc_str = intents_doc_lines.join("\n");
1713
1714 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
1716 let user_variables: Vec<String> = placeholders
1717 .iter()
1718 .filter_map(|(name, _)| {
1719 if name != "intents_doc" {
1720 Some(name.clone())
1721 } else {
1722 None
1723 }
1724 })
1725 .collect();
1726
1727 let enum_name_str = enum_name.to_string();
1729 let snake_case_name = to_snake_case(&enum_name_str);
1730 let function_name = syn::Ident::new(
1731 &format!("build_{}_prompt", snake_case_name),
1732 proc_macro2::Span::call_site(),
1733 );
1734
1735 let function_params: Vec<proc_macro2::TokenStream> = user_variables
1737 .iter()
1738 .map(|var| {
1739 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1740 quote! { #ident: &str }
1741 })
1742 .collect();
1743
1744 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1746 .iter()
1747 .map(|var| {
1748 let var_str = var.clone();
1749 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1750 quote! {
1751 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1752 }
1753 })
1754 .collect();
1755
1756 let converted_template = prompt_template.clone();
1758
1759 let extractor_name = syn::Ident::new(
1761 &format!("{}Extractor", enum_name),
1762 proc_macro2::Span::call_site(),
1763 );
1764
1765 let filtered_attrs: Vec<_> = input
1767 .attrs
1768 .iter()
1769 .filter(|attr| !attr.path().is_ident("intent"))
1770 .collect();
1771
1772 let vis = &input.vis;
1774 let generics = &input.generics;
1775 let variants = &enum_data.variants;
1776 let enum_output = quote! {
1777 #(#filtered_attrs)*
1778 #vis enum #enum_name #generics {
1779 #variants
1780 }
1781 };
1782
1783 let expanded = quote! {
1785 #enum_output
1787
1788 pub fn #function_name(#(#function_params),*) -> String {
1790 let mut env = minijinja::Environment::new();
1791 env.add_template("prompt", #converted_template)
1792 .expect("Failed to parse intent prompt template");
1793
1794 let tmpl = env.get_template("prompt").unwrap();
1795
1796 let mut __template_context = std::collections::HashMap::new();
1797
1798 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
1800
1801 #(#context_insertions)*
1803
1804 tmpl.render(&__template_context)
1805 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
1806 }
1807
1808 pub struct #extractor_name;
1810
1811 impl #extractor_name {
1812 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
1813 }
1814
1815 impl llm_toolkit::intent::IntentExtractor<#enum_name> for #extractor_name {
1816 fn extract_intent(&self, response: &str) -> Result<#enum_name, llm_toolkit::intent::IntentExtractionError> {
1817 llm_toolkit::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
1819 }
1820 }
1821 };
1822
1823 TokenStream::from(expanded)
1824}
1825
/// Converts an `UpperCamelCase` name to `snake_case`.
///
/// An underscore is inserted before an uppercase character unless it is the first character or
/// directly follows another uppercase character, so runs of capitals stay together
/// (`"HTTPServer"` becomes `"httpserver"`). Characters that are not uppercase are copied as is.
///
/// The full lowercase mapping of each uppercase character is emitted, including mappings that
/// expand to more than one character (e.g. `'İ'` lowercases to `"i\u{307}"`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        let is_upper = ch.is_uppercase();
        if is_upper {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` may yield several chars; keep all of them.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
        prev_upper = is_upper;
    }

    result
}
1846
/// Values read from an `#[action(...)]` attribute by `parse_action_attrs`.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// The string given as `tag = "..."`, if any.
    tag: Option<String>,
}
1852
1853fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
1854 let mut result = ActionAttrs::default();
1855
1856 for attr in attrs {
1857 if attr.path().is_ident("action")
1858 && let Ok(meta_list) = attr.meta.require_list()
1859 && let Ok(metas) =
1860 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1861 {
1862 for meta in metas {
1863 if let Meta::NameValue(nv) = meta
1864 && nv.path.is_ident("tag")
1865 && let syn::Expr::Lit(syn::ExprLit {
1866 lit: syn::Lit::Str(lit_str),
1867 ..
1868 }) = nv.value
1869 {
1870 result.tag = Some(lit_str.value());
1871 }
1872 }
1873 }
1874 }
1875
1876 result
1877}
1878
/// Flags read from a field-level `#[action(...)]` attribute by `parse_field_action_attrs`.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// Set when the field carries `#[action(attribute)]`.
    is_attribute: bool,
    /// Set when the field carries `#[action(inner_text)]`.
    is_inner_text: bool,
}
1885
1886fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
1887 let mut result = FieldActionAttrs::default();
1888
1889 for attr in attrs {
1890 if attr.path().is_ident("action")
1891 && let Ok(meta_list) = attr.meta.require_list()
1892 {
1893 let tokens_str = meta_list.tokens.to_string();
1894 if tokens_str == "attribute" {
1895 result.is_attribute = true;
1896 } else if tokens_str == "inner_text" {
1897 result.is_inner_text = true;
1898 }
1899 }
1900 }
1901
1902 result
1903}
1904
1905fn generate_multi_tag_actions_doc(
1907 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1908) -> String {
1909 let mut doc_lines = Vec::new();
1910
1911 for variant in variants {
1912 let action_attrs = parse_action_attrs(&variant.attrs);
1913
1914 if let Some(tag) = action_attrs.tag {
1915 let variant_docs = extract_doc_comments(&variant.attrs);
1916
1917 match &variant.fields {
1918 syn::Fields::Unit => {
1919 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
1921 }
1922 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
1923 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
1925 }
1926 syn::Fields::Named(fields) => {
1927 let mut attrs_str = Vec::new();
1929 let mut has_inner_text = false;
1930
1931 for field in &fields.named {
1932 let field_name = field.ident.as_ref().unwrap();
1933 let field_attrs = parse_field_action_attrs(&field.attrs);
1934
1935 if field_attrs.is_attribute {
1936 attrs_str.push(format!("{}=\"...\"", field_name));
1937 } else if field_attrs.is_inner_text {
1938 has_inner_text = true;
1939 }
1940 }
1941
1942 let attrs_part = if !attrs_str.is_empty() {
1943 format!(" {}", attrs_str.join(" "))
1944 } else {
1945 String::new()
1946 };
1947
1948 if has_inner_text {
1949 doc_lines.push(format!(
1950 "- `<{}{}>...</{}>`: {}",
1951 tag, attrs_part, tag, variant_docs
1952 ));
1953 } else if !attrs_str.is_empty() {
1954 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
1955 } else {
1956 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
1957 }
1958
1959 for field in &fields.named {
1961 let field_name = field.ident.as_ref().unwrap();
1962 let field_attrs = parse_field_action_attrs(&field.attrs);
1963 let field_docs = extract_doc_comments(&field.attrs);
1964
1965 if field_attrs.is_attribute {
1966 doc_lines
1967 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
1968 } else if field_attrs.is_inner_text {
1969 doc_lines
1970 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
1971 }
1972 }
1973 }
1974 _ => {
1975 }
1977 }
1978 }
1979 }
1980
1981 doc_lines.join("\n")
1982}
1983
1984fn generate_tags_regex(
1986 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1987) -> String {
1988 let mut tag_names = Vec::new();
1989
1990 for variant in variants {
1991 let action_attrs = parse_action_attrs(&variant.attrs);
1992 if let Some(tag) = action_attrs.tag {
1993 tag_names.push(tag);
1994 }
1995 }
1996
1997 if tag_names.is_empty() {
1998 return String::new();
1999 }
2000
2001 let tags_pattern = tag_names.join("|");
2002 format!(
2005 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2006 tags_pattern, tags_pattern, tags_pattern
2007 )
2008}
2009
2010fn generate_multi_tag_output(
2012 input: &DeriveInput,
2013 enum_name: &syn::Ident,
2014 enum_data: &syn::DataEnum,
2015 prompt_template: String,
2016 actions_doc: String,
2017) -> TokenStream {
2018 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2020 let user_variables: Vec<String> = placeholders
2021 .iter()
2022 .filter_map(|(name, _)| {
2023 if name != "actions_doc" {
2024 Some(name.clone())
2025 } else {
2026 None
2027 }
2028 })
2029 .collect();
2030
2031 let enum_name_str = enum_name.to_string();
2033 let snake_case_name = to_snake_case(&enum_name_str);
2034 let function_name = syn::Ident::new(
2035 &format!("build_{}_prompt", snake_case_name),
2036 proc_macro2::Span::call_site(),
2037 );
2038
2039 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2041 .iter()
2042 .map(|var| {
2043 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2044 quote! { #ident: &str }
2045 })
2046 .collect();
2047
2048 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2050 .iter()
2051 .map(|var| {
2052 let var_str = var.clone();
2053 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2054 quote! {
2055 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2056 }
2057 })
2058 .collect();
2059
2060 let extractor_name = syn::Ident::new(
2062 &format!("{}Extractor", enum_name),
2063 proc_macro2::Span::call_site(),
2064 );
2065
2066 let filtered_attrs: Vec<_> = input
2068 .attrs
2069 .iter()
2070 .filter(|attr| !attr.path().is_ident("intent"))
2071 .collect();
2072
2073 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2075 .variants
2076 .iter()
2077 .map(|variant| {
2078 let variant_name = &variant.ident;
2079 let variant_attrs: Vec<_> = variant
2080 .attrs
2081 .iter()
2082 .filter(|attr| !attr.path().is_ident("action"))
2083 .collect();
2084 let fields = &variant.fields;
2085
2086 let filtered_fields = match fields {
2088 syn::Fields::Named(named_fields) => {
2089 let filtered: Vec<_> = named_fields
2090 .named
2091 .iter()
2092 .map(|field| {
2093 let field_name = &field.ident;
2094 let field_type = &field.ty;
2095 let field_vis = &field.vis;
2096 let filtered_attrs: Vec<_> = field
2097 .attrs
2098 .iter()
2099 .filter(|attr| !attr.path().is_ident("action"))
2100 .collect();
2101 quote! {
2102 #(#filtered_attrs)*
2103 #field_vis #field_name: #field_type
2104 }
2105 })
2106 .collect();
2107 quote! { { #(#filtered,)* } }
2108 }
2109 syn::Fields::Unnamed(unnamed_fields) => {
2110 let types: Vec<_> = unnamed_fields
2111 .unnamed
2112 .iter()
2113 .map(|field| {
2114 let field_type = &field.ty;
2115 quote! { #field_type }
2116 })
2117 .collect();
2118 quote! { (#(#types),*) }
2119 }
2120 syn::Fields::Unit => quote! {},
2121 };
2122
2123 quote! {
2124 #(#variant_attrs)*
2125 #variant_name #filtered_fields
2126 }
2127 })
2128 .collect();
2129
2130 let vis = &input.vis;
2131 let generics = &input.generics;
2132
2133 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2135
2136 let tags_regex = generate_tags_regex(&enum_data.variants);
2138
2139 let expanded = quote! {
2140 #(#filtered_attrs)*
2142 #vis enum #enum_name #generics {
2143 #(#filtered_variants),*
2144 }
2145
2146 pub fn #function_name(#(#function_params),*) -> String {
2148 let mut env = minijinja::Environment::new();
2149 env.add_template("prompt", #prompt_template)
2150 .expect("Failed to parse intent prompt template");
2151
2152 let tmpl = env.get_template("prompt").unwrap();
2153
2154 let mut __template_context = std::collections::HashMap::new();
2155
2156 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2158
2159 #(#context_insertions)*
2161
2162 tmpl.render(&__template_context)
2163 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2164 }
2165
2166 pub struct #extractor_name;
2168
2169 impl #extractor_name {
2170 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2171 use ::quick_xml::events::Event;
2172 use ::quick_xml::Reader;
2173
2174 let mut actions = Vec::new();
2175 let mut reader = Reader::from_str(text);
2176 reader.config_mut().trim_text(true);
2177
2178 let mut buf = Vec::new();
2179
2180 loop {
2181 match reader.read_event_into(&mut buf) {
2182 Ok(Event::Start(e)) => {
2183 let owned_e = e.into_owned();
2184 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2185 let is_empty = false;
2186
2187 #parsing_arms
2188 }
2189 Ok(Event::Empty(e)) => {
2190 let owned_e = e.into_owned();
2191 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2192 let is_empty = true;
2193
2194 #parsing_arms
2195 }
2196 Ok(Event::Eof) => break,
2197 Err(_) => {
2198 break;
2200 }
2201 _ => {}
2202 }
2203 buf.clear();
2204 }
2205
2206 actions.into_iter().next()
2207 }
2208
2209 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, llm_toolkit::intent::IntentError> {
2210 use ::quick_xml::events::Event;
2211 use ::quick_xml::Reader;
2212
2213 let mut actions = Vec::new();
2214 let mut reader = Reader::from_str(text);
2215 reader.config_mut().trim_text(true);
2216
2217 let mut buf = Vec::new();
2218
2219 loop {
2220 match reader.read_event_into(&mut buf) {
2221 Ok(Event::Start(e)) => {
2222 let owned_e = e.into_owned();
2223 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2224 let is_empty = false;
2225
2226 #parsing_arms
2227 }
2228 Ok(Event::Empty(e)) => {
2229 let owned_e = e.into_owned();
2230 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2231 let is_empty = true;
2232
2233 #parsing_arms
2234 }
2235 Ok(Event::Eof) => break,
2236 Err(_) => {
2237 break;
2239 }
2240 _ => {}
2241 }
2242 buf.clear();
2243 }
2244
2245 Ok(actions)
2246 }
2247
2248 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2249 where
2250 F: FnMut(#enum_name) -> String,
2251 {
2252 use ::regex::Regex;
2253
2254 let regex_pattern = #tags_regex;
2255 if regex_pattern.is_empty() {
2256 return text.to_string();
2257 }
2258
2259 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2260 panic!("Failed to compile regex for action tags: {}", e);
2261 });
2262
2263 re.replace_all(text, |caps: &::regex::Captures| {
2264 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2265
2266 if let Some(action) = self.parse_single_action(matched) {
2268 transformer(action)
2269 } else {
2270 matched.to_string()
2272 }
2273 }).to_string()
2274 }
2275
2276 pub fn strip_actions(&self, text: &str) -> String {
2277 self.transform_actions(text, |_| String::new())
2278 }
2279 }
2280 };
2281
2282 TokenStream::from(expanded)
2283}
2284
2285fn generate_parsing_arms(
2287 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2288 enum_name: &syn::Ident,
2289) -> proc_macro2::TokenStream {
2290 let mut arms = Vec::new();
2291
2292 for variant in variants {
2293 let variant_name = &variant.ident;
2294 let action_attrs = parse_action_attrs(&variant.attrs);
2295
2296 if let Some(tag) = action_attrs.tag {
2297 match &variant.fields {
2298 syn::Fields::Unit => {
2299 arms.push(quote! {
2301 if &tag_name == #tag {
2302 actions.push(#enum_name::#variant_name);
2303 }
2304 });
2305 }
2306 syn::Fields::Unnamed(_fields) => {
2307 arms.push(quote! {
2309 if &tag_name == #tag && !is_empty {
2310 match reader.read_text(owned_e.name()) {
2312 Ok(text) => {
2313 actions.push(#enum_name::#variant_name(text.to_string()));
2314 }
2315 Err(_) => {
2316 actions.push(#enum_name::#variant_name(String::new()));
2318 }
2319 }
2320 }
2321 });
2322 }
2323 syn::Fields::Named(fields) => {
2324 let mut field_names = Vec::new();
2326 let mut has_inner_text_field = None;
2327
2328 for field in &fields.named {
2329 let field_name = field.ident.as_ref().unwrap();
2330 let field_attrs = parse_field_action_attrs(&field.attrs);
2331
2332 if field_attrs.is_attribute {
2333 field_names.push(field_name.clone());
2334 } else if field_attrs.is_inner_text {
2335 has_inner_text_field = Some(field_name.clone());
2336 }
2337 }
2338
2339 if let Some(inner_text_field) = has_inner_text_field {
2340 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2343 quote! {
2344 let mut #field_name = String::new();
2345 for attr in owned_e.attributes() {
2346 if let Ok(attr) = attr {
2347 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2348 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2349 break;
2350 }
2351 }
2352 }
2353 }
2354 }).collect();
2355
2356 arms.push(quote! {
2357 if &tag_name == #tag {
2358 #(#attr_extractions)*
2359
2360 if is_empty {
2362 let #inner_text_field = String::new();
2363 actions.push(#enum_name::#variant_name {
2364 #(#field_names,)*
2365 #inner_text_field,
2366 });
2367 } else {
2368 match reader.read_text(owned_e.name()) {
2370 Ok(text) => {
2371 let #inner_text_field = text.to_string();
2372 actions.push(#enum_name::#variant_name {
2373 #(#field_names,)*
2374 #inner_text_field,
2375 });
2376 }
2377 Err(_) => {
2378 let #inner_text_field = String::new();
2380 actions.push(#enum_name::#variant_name {
2381 #(#field_names,)*
2382 #inner_text_field,
2383 });
2384 }
2385 }
2386 }
2387 }
2388 });
2389 } else {
2390 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2392 quote! {
2393 let mut #field_name = String::new();
2394 for attr in owned_e.attributes() {
2395 if let Ok(attr) = attr {
2396 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2397 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2398 break;
2399 }
2400 }
2401 }
2402 }
2403 }).collect();
2404
2405 arms.push(quote! {
2406 if &tag_name == #tag {
2407 #(#attr_extractions)*
2408 actions.push(#enum_name::#variant_name {
2409 #(#field_names),*
2410 });
2411 }
2412 });
2413 }
2414 }
2415 }
2416 }
2417 }
2418
2419 quote! {
2420 #(#arms)*
2421 }
2422}
2423
2424#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2426pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2427 let input = parse_macro_input!(input as DeriveInput);
2428
2429 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2431
2432 let struct_name = &input.ident;
2433 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2434
2435 let placeholders = parse_template_placeholders_with_mode(&template);
2437
2438 let mut converted_template = template.clone();
2440 let mut context_fields = Vec::new();
2441
2442 let fields = match &input.data {
2444 Data::Struct(data_struct) => match &data_struct.fields {
2445 syn::Fields::Named(fields) => &fields.named,
2446 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2447 },
2448 _ => panic!("ToPromptFor is only supported for structs"),
2449 };
2450
2451 let has_mode_support = input.attrs.iter().any(|attr| {
2453 if attr.path().is_ident("prompt")
2454 && let Ok(metas) =
2455 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2456 {
2457 for meta in metas {
2458 if let Meta::NameValue(nv) = meta
2459 && nv.path.is_ident("mode")
2460 {
2461 return true;
2462 }
2463 }
2464 }
2465 false
2466 });
2467
2468 for (placeholder_name, mode_opt) in &placeholders {
2470 if placeholder_name == "self" {
2471 if let Some(specific_mode) = mode_opt {
2472 let unique_key = format!("self__{}", specific_mode);
2474
2475 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2477 let replacement = format!("{{{{ {} }}}}", unique_key);
2478 converted_template = converted_template.replace(&pattern, &replacement);
2479
2480 context_fields.push(quote! {
2482 context.insert(
2483 #unique_key.to_string(),
2484 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2485 );
2486 });
2487 } else {
2488 if has_mode_support {
2491 context_fields.push(quote! {
2493 context.insert(
2494 "self".to_string(),
2495 minijinja::Value::from(self.to_prompt_with_mode(mode))
2496 );
2497 });
2498 } else {
2499 context_fields.push(quote! {
2501 context.insert(
2502 "self".to_string(),
2503 minijinja::Value::from(self.to_prompt())
2504 );
2505 });
2506 }
2507 }
2508 } else {
2509 let field_exists = fields.iter().any(|f| {
2512 f.ident
2513 .as_ref()
2514 .is_some_and(|ident| ident == placeholder_name)
2515 });
2516
2517 if field_exists {
2518 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2519
2520 context_fields.push(quote! {
2524 context.insert(
2525 #placeholder_name.to_string(),
2526 minijinja::Value::from_serialize(&self.#field_ident)
2527 );
2528 });
2529 }
2530 }
2532 }
2533
2534 let expanded = quote! {
2535 impl #impl_generics llm_toolkit::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2536 where
2537 #target_type: serde::Serialize,
2538 {
2539 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2540 let mut env = minijinja::Environment::new();
2542 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2543 panic!("Failed to parse template: {}", e)
2544 });
2545
2546 let tmpl = env.get_template("prompt").unwrap();
2547
2548 let mut context = std::collections::HashMap::new();
2550 context.insert(
2552 "self".to_string(),
2553 minijinja::Value::from_serialize(self)
2554 );
2555 context.insert(
2557 "target".to_string(),
2558 minijinja::Value::from_serialize(target)
2559 );
2560 #(#context_fields)*
2561
2562 tmpl.render(context).unwrap_or_else(|e| {
2564 format!("Failed to render prompt: {}", e)
2565 })
2566 }
2567 }
2568 };
2569
2570 TokenStream::from(expanded)
2571}