1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if attrs.skip {
76 continue;
77 }
78
79 if let Some(example) = attrs.example {
81 field_values.push(quote! {
83 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
84 });
85 } else if has_default {
86 field_values.push(quote! {
88 let default_value = serde_json::to_value(&default_instance.#field_name)
89 .unwrap_or(serde_json::Value::Null);
90 json_obj.insert(#field_name_str.to_string(), default_value);
91 });
92 } else {
93 field_values.push(quote! {
95 let value = serde_json::to_value(&self.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), value);
98 });
99 }
100 }
101
102 if has_default {
103 quote! {
104 {
105 let default_instance = Self::default();
106 let mut json_obj = serde_json::Map::new();
107 #(#field_values)*
108 let json_value = serde_json::Value::Object(json_obj);
109 let json_str = serde_json::to_string_pretty(&json_value)
110 .unwrap_or_else(|_| "{}".to_string());
111 vec![#crate_path::prompt::PromptPart::Text(json_str)]
112 }
113 }
114 } else {
115 quote! {
116 {
117 let mut json_obj = serde_json::Map::new();
118 #(#field_values)*
119 let json_value = serde_json::Value::Object(json_obj);
120 let json_str = serde_json::to_string_pretty(&json_value)
121 .unwrap_or_else(|_| "{}".to_string());
122 vec![#crate_path::prompt::PromptPart::Text(json_str)]
123 }
124 }
125 }
126}
127
128fn generate_schema_only_parts(
130 struct_name: &str,
131 struct_docs: &str,
132 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
133 crate_path: &proc_macro2::TokenStream,
134) -> proc_macro2::TokenStream {
135 let mut schema_lines = vec![];
136
137 if !struct_docs.is_empty() {
139 schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
140 } else {
141 schema_lines.push(format!("### Schema for `{}`", struct_name));
142 }
143
144 schema_lines.push("{".to_string());
145
146 for (i, field) in fields.iter().enumerate() {
148 let field_name = field.ident.as_ref().unwrap();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if attrs.skip {
153 continue;
154 }
155
156 let field_docs = extract_doc_comments(&field.attrs);
158
159 let type_str = format_type_for_schema(&field.ty);
161
162 let mut field_line = format!(" \"{}\": \"{}\"", field_name, type_str);
164
165 if !field_docs.is_empty() {
167 field_line.push_str(&format!(", // {}", field_docs));
168 }
169
170 let remaining_fields = fields
172 .iter()
173 .skip(i + 1)
174 .filter(|f| {
175 let attrs = parse_field_prompt_attrs(&f.attrs);
176 !attrs.skip
177 })
178 .count();
179
180 if remaining_fields > 0 {
181 field_line.push(',');
182 }
183
184 schema_lines.push(field_line);
185 }
186
187 schema_lines.push("}".to_string());
188
189 let schema_str = schema_lines.join("\n");
190
191 quote! {
192 vec![#crate_path::prompt::PromptPart::Text(#schema_str.to_string())]
193 }
194}
195
196fn format_type_for_schema(ty: &syn::Type) -> String {
198 match ty {
200 syn::Type::Path(type_path) => {
201 let path = &type_path.path;
202 if let Some(last_segment) = path.segments.last() {
203 let type_name = last_segment.ident.to_string();
204
205 if type_name == "Option"
207 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
208 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
209 {
210 return format!("{} | null", format_type_for_schema(inner_type));
211 }
212
213 match type_name.as_str() {
215 "String" | "str" => "string".to_string(),
216 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
217 | "u64" | "u128" | "usize" => "number".to_string(),
218 "f32" | "f64" => "number".to_string(),
219 "bool" => "boolean".to_string(),
220 "Vec" => {
221 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
222 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
223 {
224 return format!("{}[]", format_type_for_schema(inner_type));
225 }
226 "array".to_string()
227 }
228 _ => type_name.to_lowercase(),
229 }
230 } else {
231 "unknown".to_string()
232 }
233 }
234 _ => "unknown".to_string(),
235 }
236}
237
238enum PromptAttribute {
240 Skip,
241 Description(String),
242 None,
243}
244
245fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
247 for attr in attrs {
248 if attr.path().is_ident("prompt") {
249 if let Ok(meta_list) = attr.meta.require_list() {
251 let tokens = &meta_list.tokens;
252 let tokens_str = tokens.to_string();
253 if tokens_str == "skip" {
254 return PromptAttribute::Skip;
255 }
256 }
257
258 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
260 return PromptAttribute::Description(lit_str.value());
261 }
262 }
263 }
264 PromptAttribute::None
265}
266
267#[derive(Debug, Default)]
269struct FieldPromptAttrs {
270 skip: bool,
271 rename: Option<String>,
272 format_with: Option<String>,
273 image: bool,
274 example: Option<String>,
275}
276
277fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
279 let mut result = FieldPromptAttrs::default();
280
281 for attr in attrs {
282 if attr.path().is_ident("prompt") {
283 if let Ok(meta_list) = attr.meta.require_list() {
285 if let Ok(metas) =
287 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
288 {
289 for meta in metas {
290 match meta {
291 Meta::Path(path) if path.is_ident("skip") => {
292 result.skip = true;
293 }
294 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
295 if let syn::Expr::Lit(syn::ExprLit {
296 lit: syn::Lit::Str(lit_str),
297 ..
298 }) = nv.value
299 {
300 result.rename = Some(lit_str.value());
301 }
302 }
303 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
304 if let syn::Expr::Lit(syn::ExprLit {
305 lit: syn::Lit::Str(lit_str),
306 ..
307 }) = nv.value
308 {
309 result.format_with = Some(lit_str.value());
310 }
311 }
312 Meta::Path(path) if path.is_ident("image") => {
313 result.image = true;
314 }
315 Meta::NameValue(nv) if nv.path.is_ident("example") => {
316 if let syn::Expr::Lit(syn::ExprLit {
317 lit: syn::Lit::Str(lit_str),
318 ..
319 }) = nv.value
320 {
321 result.example = Some(lit_str.value());
322 }
323 }
324 _ => {}
325 }
326 }
327 } else if meta_list.tokens.to_string() == "skip" {
328 result.skip = true;
330 } else if meta_list.tokens.to_string() == "image" {
331 result.image = true;
333 }
334 }
335 }
336 }
337
338 result
339}
340
341#[proc_macro_derive(ToPrompt, attributes(prompt))]
384pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
385 let input = parse_macro_input!(input as DeriveInput);
386
387 let found_crate =
388 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
389 let crate_path = match found_crate {
390 FoundCrate::Itself => {
391 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
393 quote!(::#ident)
394 }
395 FoundCrate::Name(name) => {
396 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
397 quote!(::#ident)
398 }
399 };
400
401 match &input.data {
403 Data::Enum(data_enum) => {
404 let enum_name = &input.ident;
406 let enum_docs = extract_doc_comments(&input.attrs);
407
408 let mut prompt_lines = Vec::new();
409
410 if !enum_docs.is_empty() {
412 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
413 } else {
414 prompt_lines.push(format!("{}:", enum_name));
415 }
416 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
418
419 for variant in &data_enum.variants {
421 let variant_name = &variant.ident;
422
423 match parse_prompt_attribute(&variant.attrs) {
425 PromptAttribute::Skip => {
426 continue;
428 }
429 PromptAttribute::Description(desc) => {
430 prompt_lines.push(format!("- {}: {}", variant_name, desc));
432 }
433 PromptAttribute::None => {
434 let variant_docs = extract_doc_comments(&variant.attrs);
436 if !variant_docs.is_empty() {
437 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
438 } else {
439 prompt_lines.push(format!("- {}", variant_name));
440 }
441 }
442 }
443 }
444
445 let prompt_string = prompt_lines.join("\n");
446 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
447
448 let mut match_arms = Vec::new();
450 for variant in &data_enum.variants {
451 let variant_name = &variant.ident;
452
453 match parse_prompt_attribute(&variant.attrs) {
455 PromptAttribute::Skip => {
456 match_arms.push(quote! {
458 Self::#variant_name => stringify!(#variant_name).to_string()
459 });
460 }
461 PromptAttribute::Description(desc) => {
462 match_arms.push(quote! {
464 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
465 });
466 }
467 PromptAttribute::None => {
468 let variant_docs = extract_doc_comments(&variant.attrs);
470 if !variant_docs.is_empty() {
471 match_arms.push(quote! {
472 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
473 });
474 } else {
475 match_arms.push(quote! {
476 Self::#variant_name => stringify!(#variant_name).to_string()
477 });
478 }
479 }
480 }
481 }
482
483 let to_prompt_impl = if match_arms.is_empty() {
484 quote! {
486 fn to_prompt(&self) -> String {
487 match *self {}
488 }
489 }
490 } else {
491 quote! {
492 fn to_prompt(&self) -> String {
493 match self {
494 #(#match_arms),*
495 }
496 }
497 }
498 };
499
500 let expanded = quote! {
501 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
502 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
503 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
504 }
505
506 #to_prompt_impl
507
508 fn prompt_schema() -> String {
509 #prompt_string.to_string()
510 }
511 }
512 };
513
514 TokenStream::from(expanded)
515 }
516 Data::Struct(data_struct) => {
517 let mut template_attr = None;
519 let mut template_file_attr = None;
520 let mut mode_attr = None;
521 let mut validate_attr = false;
522
523 for attr in &input.attrs {
524 if attr.path().is_ident("prompt") {
525 if let Ok(metas) =
527 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
528 {
529 for meta in metas {
530 match meta {
531 Meta::NameValue(nv) if nv.path.is_ident("template") => {
532 if let syn::Expr::Lit(expr_lit) = nv.value
533 && let syn::Lit::Str(lit_str) = expr_lit.lit
534 {
535 template_attr = Some(lit_str.value());
536 }
537 }
538 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
539 if let syn::Expr::Lit(expr_lit) = nv.value
540 && let syn::Lit::Str(lit_str) = expr_lit.lit
541 {
542 template_file_attr = Some(lit_str.value());
543 }
544 }
545 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
546 if let syn::Expr::Lit(expr_lit) = nv.value
547 && let syn::Lit::Str(lit_str) = expr_lit.lit
548 {
549 mode_attr = Some(lit_str.value());
550 }
551 }
552 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
553 if let syn::Expr::Lit(expr_lit) = nv.value
554 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
555 {
556 validate_attr = lit_bool.value();
557 }
558 }
559 _ => {}
560 }
561 }
562 }
563 }
564 }
565
566 if template_attr.is_some() && template_file_attr.is_some() {
568 return syn::Error::new(
569 input.ident.span(),
570 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
571 ).to_compile_error().into();
572 }
573
574 let template_str = if let Some(file_path) = template_file_attr {
576 let mut full_path = None;
580
581 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
583 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
585
586 if !is_trybuild {
587 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
589 if candidate.exists() {
590 full_path = Some(candidate);
591 }
592 } else {
593 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
599 let workspace_root = &manifest_dir[..target_pos];
600 let original_macros_dir = std::path::Path::new(workspace_root)
602 .join("crates")
603 .join("llm-toolkit-macros");
604
605 let candidate = original_macros_dir.join(&file_path);
606 if candidate.exists() {
607 full_path = Some(candidate);
608 }
609 }
610 }
611 }
612
613 if full_path.is_none() {
615 let candidate = std::path::Path::new(&file_path).to_path_buf();
616 if candidate.exists() {
617 full_path = Some(candidate);
618 }
619 }
620
621 if full_path.is_none()
624 && let Ok(current_dir) = std::env::current_dir()
625 {
626 let mut search_dir = current_dir.as_path();
627 for _ in 0..10 {
629 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
631 if macros_dir.exists() {
632 let candidate = macros_dir.join(&file_path);
633 if candidate.exists() {
634 full_path = Some(candidate);
635 break;
636 }
637 }
638 let candidate = search_dir.join(&file_path);
640 if candidate.exists() {
641 full_path = Some(candidate);
642 break;
643 }
644 if let Some(parent) = search_dir.parent() {
645 search_dir = parent;
646 } else {
647 break;
648 }
649 }
650 }
651
652 if full_path.is_none() {
654 let mut error_msg = format!(
656 "Template file '{}' not found at compile time.\n\nSearched in:",
657 file_path
658 );
659
660 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
661 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
662 error_msg.push_str(&format!("\n - {}", candidate.display()));
663 }
664
665 if let Ok(current_dir) = std::env::current_dir() {
666 let candidate = current_dir.join(&file_path);
667 error_msg.push_str(&format!("\n - {}", candidate.display()));
668 }
669
670 error_msg.push_str("\n\nPlease ensure:");
671 error_msg.push_str("\n 1. The template file exists");
672 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
673 error_msg.push_str("\n 3. There are no typos in the path");
674
675 return syn::Error::new(input.ident.span(), error_msg)
676 .to_compile_error()
677 .into();
678 }
679
680 let final_path = full_path.unwrap();
681
682 match std::fs::read_to_string(&final_path) {
684 Ok(content) => Some(content),
685 Err(e) => {
686 return syn::Error::new(
687 input.ident.span(),
688 format!(
689 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
690 file_path,
691 e,
692 final_path.display()
693 ),
694 )
695 .to_compile_error()
696 .into();
697 }
698 }
699 } else {
700 template_attr
701 };
702
703 if validate_attr && let Some(template) = &template_str {
705 let mut env = minijinja::Environment::new();
707 if let Err(e) = env.add_template("validation", template) {
708 let warning_msg =
710 format!("Template validation warning: Invalid Jinja syntax - {}", e);
711 let warning_ident = syn::Ident::new(
712 "TEMPLATE_VALIDATION_WARNING",
713 proc_macro2::Span::call_site(),
714 );
715 let _warning_tokens = quote! {
716 #[deprecated(note = #warning_msg)]
717 const #warning_ident: () = ();
718 let _ = #warning_ident;
719 };
720 eprintln!("cargo:warning={}", warning_msg);
722 }
723
724 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
726 &fields.named
727 } else {
728 panic!("Template validation is only supported for structs with named fields.");
729 };
730
731 let field_names: std::collections::HashSet<String> = fields
732 .iter()
733 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
734 .collect();
735
736 let placeholders = parse_template_placeholders_with_mode(template);
738
739 for (placeholder_name, _mode) in &placeholders {
740 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
741 let warning_msg = format!(
742 "Template validation warning: Variable '{}' used in template but not found in struct fields",
743 placeholder_name
744 );
745 eprintln!("cargo:warning={}", warning_msg);
746 }
747 }
748 }
749
750 let name = input.ident;
751 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
752
753 let struct_docs = extract_doc_comments(&input.attrs);
755
756 let is_mode_based =
758 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
759
760 let expanded = if is_mode_based || mode_attr.is_some() {
761 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
763 &fields.named
764 } else {
765 panic!(
766 "Mode-based prompt generation is only supported for structs with named fields."
767 );
768 };
769
770 let struct_name_str = name.to_string();
771
772 let has_default = input.attrs.iter().any(|attr| {
774 if attr.path().is_ident("derive")
775 && let Ok(meta_list) = attr.meta.require_list()
776 {
777 let tokens_str = meta_list.tokens.to_string();
778 tokens_str.contains("Default")
779 } else {
780 false
781 }
782 });
783
784 let schema_parts =
786 generate_schema_only_parts(&struct_name_str, &struct_docs, fields, &crate_path);
787
788 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
790
791 quote! {
792 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
793 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
794 match mode {
795 "schema_only" => #schema_parts,
796 "example_only" => #example_parts,
797 "full" | _ => {
798 let mut parts = Vec::new();
800
801 let schema_parts = #schema_parts;
803 parts.extend(schema_parts);
804
805 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
807 parts.push(#crate_path::prompt::PromptPart::Text(
808 format!("Here is an example of a valid `{}` object:", #struct_name_str)
809 ));
810
811 let example_parts = #example_parts;
813 parts.extend(example_parts);
814
815 parts
816 }
817 }
818 }
819
820 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
821 self.to_prompt_parts_with_mode("full")
822 }
823
824 fn to_prompt(&self) -> String {
825 self.to_prompt_parts()
826 .into_iter()
827 .filter_map(|part| match part {
828 #crate_path::prompt::PromptPart::Text(text) => Some(text),
829 _ => None,
830 })
831 .collect::<Vec<_>>()
832 .join("\n")
833 }
834
835 fn prompt_schema() -> String {
836 let schema_parts = #schema_parts;
837 schema_parts
838 .into_iter()
839 .filter_map(|part| match part {
840 #crate_path::prompt::PromptPart::Text(text) => Some(text),
841 _ => None,
842 })
843 .collect::<Vec<_>>()
844 .join("\n")
845 }
846 }
847 }
848 } else if let Some(template) = template_str {
849 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
852 &fields.named
853 } else {
854 panic!(
855 "Template prompt generation is only supported for structs with named fields."
856 );
857 };
858
859 let placeholders = parse_template_placeholders_with_mode(&template);
861 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
863 mode.is_some()
864 && fields
865 .iter()
866 .any(|f| f.ident.as_ref().unwrap() == field_name)
867 });
868
869 let mut image_field_parts = Vec::new();
870 for f in fields.iter() {
871 let field_name = f.ident.as_ref().unwrap();
872 let attrs = parse_field_prompt_attrs(&f.attrs);
873
874 if attrs.image {
875 image_field_parts.push(quote! {
877 parts.extend(self.#field_name.to_prompt_parts());
878 });
879 }
880 }
881
882 if has_mode_syntax {
884 let mut context_fields = Vec::new();
886 let mut modified_template = template.clone();
887
888 for (field_name, mode_opt) in &placeholders {
890 if let Some(mode) = mode_opt {
891 let unique_key = format!("{}__{}", field_name, mode);
893
894 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
896 let replacement = format!("{{{{ {} }}}}", unique_key);
897 modified_template = modified_template.replace(&pattern, &replacement);
898
899 let field_ident =
901 syn::Ident::new(field_name, proc_macro2::Span::call_site());
902
903 context_fields.push(quote! {
905 context.insert(
906 #unique_key.to_string(),
907 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
908 );
909 });
910 }
911 }
912
913 for field in fields.iter() {
915 let field_name = field.ident.as_ref().unwrap();
916 let field_name_str = field_name.to_string();
917
918 let has_mode_entry = placeholders
920 .iter()
921 .any(|(name, mode)| name == &field_name_str && mode.is_some());
922
923 if !has_mode_entry {
924 let is_primitive = match &field.ty {
927 syn::Type::Path(type_path) => {
928 if let Some(segment) = type_path.path.segments.last() {
929 let type_name = segment.ident.to_string();
930 matches!(
931 type_name.as_str(),
932 "String"
933 | "str"
934 | "i8"
935 | "i16"
936 | "i32"
937 | "i64"
938 | "i128"
939 | "isize"
940 | "u8"
941 | "u16"
942 | "u32"
943 | "u64"
944 | "u128"
945 | "usize"
946 | "f32"
947 | "f64"
948 | "bool"
949 | "char"
950 )
951 } else {
952 false
953 }
954 }
955 _ => false,
956 };
957
958 if is_primitive {
959 context_fields.push(quote! {
960 context.insert(
961 #field_name_str.to_string(),
962 minijinja::Value::from_serialize(&self.#field_name)
963 );
964 });
965 } else {
966 context_fields.push(quote! {
968 context.insert(
969 #field_name_str.to_string(),
970 minijinja::Value::from(self.#field_name.to_prompt())
971 );
972 });
973 }
974 }
975 }
976
977 quote! {
978 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
979 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
980 let mut parts = Vec::new();
981
982 #(#image_field_parts)*
984
985 let text = {
987 let mut env = minijinja::Environment::new();
988 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
989 panic!("Failed to parse template: {}", e)
990 });
991
992 let tmpl = env.get_template("prompt").unwrap();
993
994 let mut context = std::collections::HashMap::new();
995 #(#context_fields)*
996
997 tmpl.render(context).unwrap_or_else(|e| {
998 format!("Failed to render prompt: {}", e)
999 })
1000 };
1001
1002 if !text.is_empty() {
1003 parts.push(#crate_path::prompt::PromptPart::Text(text));
1004 }
1005
1006 parts
1007 }
1008
1009 fn to_prompt(&self) -> String {
1010 let mut env = minijinja::Environment::new();
1012 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1013 panic!("Failed to parse template: {}", e)
1014 });
1015
1016 let tmpl = env.get_template("prompt").unwrap();
1017
1018 let mut context = std::collections::HashMap::new();
1019 #(#context_fields)*
1020
1021 tmpl.render(context).unwrap_or_else(|e| {
1022 format!("Failed to render prompt: {}", e)
1023 })
1024 }
1025
1026 fn prompt_schema() -> String {
1027 String::new() }
1029 }
1030 }
1031 } else {
1032 quote! {
1034 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1035 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1036 let mut parts = Vec::new();
1037
1038 #(#image_field_parts)*
1040
1041 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1043 format!("Failed to render prompt: {}", e)
1044 });
1045 if !text.is_empty() {
1046 parts.push(#crate_path::prompt::PromptPart::Text(text));
1047 }
1048
1049 parts
1050 }
1051
1052 fn to_prompt(&self) -> String {
1053 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1054 format!("Failed to render prompt: {}", e)
1055 })
1056 }
1057
1058 fn prompt_schema() -> String {
1059 String::new() }
1061 }
1062 }
1063 }
1064 } else {
1065 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1068 &fields.named
1069 } else {
1070 panic!(
1071 "Default prompt generation is only supported for structs with named fields."
1072 );
1073 };
1074
1075 let mut text_field_parts = Vec::new();
1077 let mut image_field_parts = Vec::new();
1078
1079 for f in fields.iter() {
1080 let field_name = f.ident.as_ref().unwrap();
1081 let attrs = parse_field_prompt_attrs(&f.attrs);
1082
1083 if attrs.skip {
1085 continue;
1086 }
1087
1088 if attrs.image {
1089 image_field_parts.push(quote! {
1091 parts.extend(self.#field_name.to_prompt_parts());
1092 });
1093 } else {
1094 let key = if let Some(rename) = attrs.rename {
1100 rename
1101 } else {
1102 let doc_comment = extract_doc_comments(&f.attrs);
1103 if !doc_comment.is_empty() {
1104 doc_comment
1105 } else {
1106 field_name.to_string()
1107 }
1108 };
1109
1110 let value_expr = if let Some(format_with) = attrs.format_with {
1112 let func_path: syn::Path =
1114 syn::parse_str(&format_with).unwrap_or_else(|_| {
1115 panic!("Invalid function path: {}", format_with)
1116 });
1117 quote! { #func_path(&self.#field_name) }
1118 } else {
1119 quote! { self.#field_name.to_prompt() }
1120 };
1121
1122 text_field_parts.push(quote! {
1123 text_parts.push(format!("{}: {}", #key, #value_expr));
1124 });
1125 }
1126 }
1127
1128 quote! {
1130 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1131 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1132 let mut parts = Vec::new();
1133
1134 #(#image_field_parts)*
1136
1137 let mut text_parts = Vec::new();
1139 #(#text_field_parts)*
1140
1141 if !text_parts.is_empty() {
1142 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1143 }
1144
1145 parts
1146 }
1147
1148 fn to_prompt(&self) -> String {
1149 let mut text_parts = Vec::new();
1150 #(#text_field_parts)*
1151 text_parts.join("\n")
1152 }
1153
1154 fn prompt_schema() -> String {
1155 String::new() }
1157 }
1158 }
1159 };
1160
1161 TokenStream::from(expanded)
1162 }
1163 Data::Union(_) => {
1164 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1165 }
1166 }
1167}
1168
/// One named rendering target for `ToPromptSet`.
///
/// Created from a struct-level `#[prompt_for(name = "...")]`, or implicitly when a
/// field-level `#[prompt_for(name = "...")]` names a target not declared on the struct.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name (the `name = "..."` value).
    name: String,
    /// Optional `template = "..."` from the struct-level attribute; `None` for
    /// targets that only appear on fields.
    template: Option<String>,
    /// Per-field settings for this target, keyed by the field's identifier.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}

/// Field-level `#[prompt_for(...)]` settings that apply to a single target.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: omit the field for this target (or for all targets via `#[prompt_for(skip)]`).
    skip: bool,
    /// `rename = "..."`: label override for this target.
    rename: Option<String>,
    /// `format_with = "..."`: formatter function path for this target.
    format_with: Option<String>,
    /// `image`: the field contributes its own prompt parts for this target.
    image: bool,
    /// Set to `true` whenever the attribute named a target explicitly.
    include_only: bool,
}
1186
1187fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1189 let mut configs = Vec::new();
1190
1191 for attr in attrs {
1192 if attr.path().is_ident("prompt_for")
1193 && let Ok(meta_list) = attr.meta.require_list()
1194 {
1195 if meta_list.tokens.to_string() == "skip" {
1197 let config = FieldTargetConfig {
1199 skip: true,
1200 ..Default::default()
1201 };
1202 configs.push(("*".to_string(), config));
1203 } else if let Ok(metas) =
1204 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1205 {
1206 let mut target_name = None;
1207 let mut config = FieldTargetConfig::default();
1208
1209 for meta in metas {
1210 match meta {
1211 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1212 if let syn::Expr::Lit(syn::ExprLit {
1213 lit: syn::Lit::Str(lit_str),
1214 ..
1215 }) = nv.value
1216 {
1217 target_name = Some(lit_str.value());
1218 }
1219 }
1220 Meta::Path(path) if path.is_ident("skip") => {
1221 config.skip = true;
1222 }
1223 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1224 if let syn::Expr::Lit(syn::ExprLit {
1225 lit: syn::Lit::Str(lit_str),
1226 ..
1227 }) = nv.value
1228 {
1229 config.rename = Some(lit_str.value());
1230 }
1231 }
1232 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1233 if let syn::Expr::Lit(syn::ExprLit {
1234 lit: syn::Lit::Str(lit_str),
1235 ..
1236 }) = nv.value
1237 {
1238 config.format_with = Some(lit_str.value());
1239 }
1240 }
1241 Meta::Path(path) if path.is_ident("image") => {
1242 config.image = true;
1243 }
1244 _ => {}
1245 }
1246 }
1247
1248 if let Some(name) = target_name {
1249 config.include_only = true;
1250 configs.push((name, config));
1251 }
1252 }
1253 }
1254 }
1255
1256 configs
1257}
1258
1259fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1261 let mut targets = Vec::new();
1262
1263 for attr in attrs {
1264 if attr.path().is_ident("prompt_for")
1265 && let Ok(meta_list) = attr.meta.require_list()
1266 && let Ok(metas) =
1267 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1268 {
1269 let mut target_name = None;
1270 let mut template = None;
1271
1272 for meta in metas {
1273 match meta {
1274 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1275 if let syn::Expr::Lit(syn::ExprLit {
1276 lit: syn::Lit::Str(lit_str),
1277 ..
1278 }) = nv.value
1279 {
1280 target_name = Some(lit_str.value());
1281 }
1282 }
1283 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1284 if let syn::Expr::Lit(syn::ExprLit {
1285 lit: syn::Lit::Str(lit_str),
1286 ..
1287 }) = nv.value
1288 {
1289 template = Some(lit_str.value());
1290 }
1291 }
1292 _ => {}
1293 }
1294 }
1295
1296 if let Some(name) = target_name {
1297 targets.push(TargetInfo {
1298 name,
1299 template,
1300 field_configs: std::collections::HashMap::new(),
1301 });
1302 }
1303 }
1304 }
1305
1306 targets
1307}
1308
1309#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1310pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1311 let input = parse_macro_input!(input as DeriveInput);
1312
1313 let found_crate =
1314 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1315 let crate_path = match found_crate {
1316 FoundCrate::Itself => {
1317 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1319 quote!(::#ident)
1320 }
1321 FoundCrate::Name(name) => {
1322 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1323 quote!(::#ident)
1324 }
1325 };
1326
1327 let data_struct = match &input.data {
1329 Data::Struct(data) => data,
1330 _ => {
1331 return syn::Error::new(
1332 input.ident.span(),
1333 "`#[derive(ToPromptSet)]` is only supported for structs",
1334 )
1335 .to_compile_error()
1336 .into();
1337 }
1338 };
1339
1340 let fields = match &data_struct.fields {
1341 syn::Fields::Named(fields) => &fields.named,
1342 _ => {
1343 return syn::Error::new(
1344 input.ident.span(),
1345 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1346 )
1347 .to_compile_error()
1348 .into();
1349 }
1350 };
1351
1352 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1354
1355 for field in fields.iter() {
1357 let field_name = field.ident.as_ref().unwrap().to_string();
1358 let field_configs = parse_prompt_for_attrs(&field.attrs);
1359
1360 for (target_name, config) in field_configs {
1361 if target_name == "*" {
1362 for target in &mut targets {
1364 target
1365 .field_configs
1366 .entry(field_name.clone())
1367 .or_insert_with(FieldTargetConfig::default)
1368 .skip = config.skip;
1369 }
1370 } else {
1371 let target_exists = targets.iter().any(|t| t.name == target_name);
1373 if !target_exists {
1374 targets.push(TargetInfo {
1376 name: target_name.clone(),
1377 template: None,
1378 field_configs: std::collections::HashMap::new(),
1379 });
1380 }
1381
1382 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1383
1384 target.field_configs.insert(field_name.clone(), config);
1385 }
1386 }
1387 }
1388
1389 let mut match_arms = Vec::new();
1391
1392 for target in &targets {
1393 let target_name = &target.name;
1394
1395 if let Some(template_str) = &target.template {
1396 let mut image_parts = Vec::new();
1398
1399 for field in fields.iter() {
1400 let field_name = field.ident.as_ref().unwrap();
1401 let field_name_str = field_name.to_string();
1402
1403 if let Some(config) = target.field_configs.get(&field_name_str)
1404 && config.image
1405 {
1406 image_parts.push(quote! {
1407 parts.extend(self.#field_name.to_prompt_parts());
1408 });
1409 }
1410 }
1411
1412 match_arms.push(quote! {
1413 #target_name => {
1414 let mut parts = Vec::new();
1415
1416 #(#image_parts)*
1417
1418 let text = #crate_path::prompt::render_prompt(#template_str, self)
1419 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1420 target: #target_name.to_string(),
1421 source: e,
1422 })?;
1423
1424 if !text.is_empty() {
1425 parts.push(#crate_path::prompt::PromptPart::Text(text));
1426 }
1427
1428 Ok(parts)
1429 }
1430 });
1431 } else {
1432 let mut text_field_parts = Vec::new();
1434 let mut image_field_parts = Vec::new();
1435
1436 for field in fields.iter() {
1437 let field_name = field.ident.as_ref().unwrap();
1438 let field_name_str = field_name.to_string();
1439
1440 let config = target.field_configs.get(&field_name_str);
1442
1443 if let Some(cfg) = config
1445 && cfg.skip
1446 {
1447 continue;
1448 }
1449
1450 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1454 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1455 .iter()
1456 .any(|(name, _)| name != "*");
1457
1458 if has_any_target_specific_config && !is_explicitly_for_this_target {
1459 continue;
1460 }
1461
1462 if let Some(cfg) = config {
1463 if cfg.image {
1464 image_field_parts.push(quote! {
1465 parts.extend(self.#field_name.to_prompt_parts());
1466 });
1467 } else {
1468 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1469
1470 let value_expr = if let Some(format_with) = &cfg.format_with {
1471 match syn::parse_str::<syn::Path>(format_with) {
1473 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1474 Err(_) => {
1475 let error_msg = format!(
1477 "Invalid function path in format_with: '{}'",
1478 format_with
1479 );
1480 quote! {
1481 compile_error!(#error_msg);
1482 String::new()
1483 }
1484 }
1485 }
1486 } else {
1487 quote! { self.#field_name.to_prompt() }
1488 };
1489
1490 text_field_parts.push(quote! {
1491 text_parts.push(format!("{}: {}", #key, #value_expr));
1492 });
1493 }
1494 } else {
1495 text_field_parts.push(quote! {
1497 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1498 });
1499 }
1500 }
1501
1502 match_arms.push(quote! {
1503 #target_name => {
1504 let mut parts = Vec::new();
1505
1506 #(#image_field_parts)*
1507
1508 let mut text_parts = Vec::new();
1509 #(#text_field_parts)*
1510
1511 if !text_parts.is_empty() {
1512 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1513 }
1514
1515 Ok(parts)
1516 }
1517 });
1518 }
1519 }
1520
1521 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1523
1524 match_arms.push(quote! {
1526 _ => {
1527 let available = vec![#(#target_names.to_string()),*];
1528 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1529 target: target.to_string(),
1530 available,
1531 })
1532 }
1533 });
1534
1535 let struct_name = &input.ident;
1536 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1537
1538 let expanded = quote! {
1539 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1540 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1541 match target {
1542 #(#match_arms)*
1543 }
1544 }
1545 }
1546 };
1547
1548 TokenStream::from(expanded)
1549}
1550
1551struct TypeList {
1553 types: Punctuated<syn::Type, Token![,]>,
1554}
1555
1556impl Parse for TypeList {
1557 fn parse(input: ParseStream) -> syn::Result<Self> {
1558 Ok(TypeList {
1559 types: Punctuated::parse_terminated(input)?,
1560 })
1561 }
1562}
1563
1564#[proc_macro]
1588pub fn examples_section(input: TokenStream) -> TokenStream {
1589 let input = parse_macro_input!(input as TypeList);
1590
1591 let found_crate =
1592 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1593 let _crate_path = match found_crate {
1594 FoundCrate::Itself => quote!(crate),
1595 FoundCrate::Name(name) => {
1596 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1597 quote!(::#ident)
1598 }
1599 };
1600
1601 let mut type_sections = Vec::new();
1603
1604 for ty in input.types.iter() {
1605 let type_name_str = quote!(#ty).to_string();
1607
1608 type_sections.push(quote! {
1610 {
1611 let type_name = #type_name_str;
1612 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1613 format!("---\n#### `{}`\n{}", type_name, json_example)
1614 }
1615 });
1616 }
1617
1618 let expanded = quote! {
1620 {
1621 let mut sections = Vec::new();
1622 sections.push("---".to_string());
1623 sections.push("### Examples".to_string());
1624 sections.push("".to_string());
1625 sections.push("Here are examples of the data structures you should use.".to_string());
1626 sections.push("".to_string());
1627
1628 #(sections.push(#type_sections);)*
1629
1630 sections.push("---".to_string());
1631
1632 sections.join("\n")
1633 }
1634 };
1635
1636 TokenStream::from(expanded)
1637}
1638
1639fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1641 for attr in attrs {
1642 if attr.path().is_ident("prompt_for")
1643 && let Ok(meta_list) = attr.meta.require_list()
1644 && let Ok(metas) =
1645 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1646 {
1647 let mut target_type = None;
1648 let mut template = None;
1649
1650 for meta in metas {
1651 match meta {
1652 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1653 if let syn::Expr::Lit(syn::ExprLit {
1654 lit: syn::Lit::Str(lit_str),
1655 ..
1656 }) = nv.value
1657 {
1658 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1660 }
1661 }
1662 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1663 if let syn::Expr::Lit(syn::ExprLit {
1664 lit: syn::Lit::Str(lit_str),
1665 ..
1666 }) = nv.value
1667 {
1668 template = Some(lit_str.value());
1669 }
1670 }
1671 _ => {}
1672 }
1673 }
1674
1675 if let (Some(target), Some(tmpl)) = (target_type, template) {
1676 return (target, tmpl);
1677 }
1678 }
1679 }
1680
1681 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1682}
1683
1684#[proc_macro_attribute]
1718pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1719 let input = parse_macro_input!(item as DeriveInput);
1720
1721 let found_crate =
1722 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1723 let crate_path = match found_crate {
1724 FoundCrate::Itself => {
1725 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1727 quote!(::#ident)
1728 }
1729 FoundCrate::Name(name) => {
1730 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1731 quote!(::#ident)
1732 }
1733 };
1734
1735 let enum_data = match &input.data {
1737 Data::Enum(data) => data,
1738 _ => {
1739 return syn::Error::new(
1740 input.ident.span(),
1741 "`#[define_intent]` can only be applied to enums",
1742 )
1743 .to_compile_error()
1744 .into();
1745 }
1746 };
1747
1748 let mut prompt_template = None;
1750 let mut extractor_tag = None;
1751 let mut mode = None;
1752
1753 for attr in &input.attrs {
1754 if attr.path().is_ident("intent")
1755 && let Ok(metas) =
1756 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1757 {
1758 for meta in metas {
1759 match meta {
1760 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1761 if let syn::Expr::Lit(syn::ExprLit {
1762 lit: syn::Lit::Str(lit_str),
1763 ..
1764 }) = nv.value
1765 {
1766 prompt_template = Some(lit_str.value());
1767 }
1768 }
1769 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1770 if let syn::Expr::Lit(syn::ExprLit {
1771 lit: syn::Lit::Str(lit_str),
1772 ..
1773 }) = nv.value
1774 {
1775 extractor_tag = Some(lit_str.value());
1776 }
1777 }
1778 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1779 if let syn::Expr::Lit(syn::ExprLit {
1780 lit: syn::Lit::Str(lit_str),
1781 ..
1782 }) = nv.value
1783 {
1784 mode = Some(lit_str.value());
1785 }
1786 }
1787 _ => {}
1788 }
1789 }
1790 }
1791 }
1792
1793 let mode = mode.unwrap_or_else(|| "single".to_string());
1795
1796 if mode != "single" && mode != "multi_tag" {
1798 return syn::Error::new(
1799 input.ident.span(),
1800 "`mode` must be either \"single\" or \"multi_tag\"",
1801 )
1802 .to_compile_error()
1803 .into();
1804 }
1805
1806 let prompt_template = match prompt_template {
1808 Some(p) => p,
1809 None => {
1810 return syn::Error::new(
1811 input.ident.span(),
1812 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1813 )
1814 .to_compile_error()
1815 .into();
1816 }
1817 };
1818
1819 if mode == "multi_tag" {
1821 let enum_name = &input.ident;
1822 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1823 return generate_multi_tag_output(
1824 &input,
1825 enum_name,
1826 enum_data,
1827 prompt_template,
1828 actions_doc,
1829 );
1830 }
1831
1832 let extractor_tag = match extractor_tag {
1834 Some(t) => t,
1835 None => {
1836 return syn::Error::new(
1837 input.ident.span(),
1838 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1839 )
1840 .to_compile_error()
1841 .into();
1842 }
1843 };
1844
1845 let enum_name = &input.ident;
1847 let enum_docs = extract_doc_comments(&input.attrs);
1848
1849 let mut intents_doc_lines = Vec::new();
1850
1851 if !enum_docs.is_empty() {
1853 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1854 } else {
1855 intents_doc_lines.push(format!("{}:", enum_name));
1856 }
1857 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
1859
1860 for variant in &enum_data.variants {
1862 let variant_name = &variant.ident;
1863 let variant_docs = extract_doc_comments(&variant.attrs);
1864
1865 if !variant_docs.is_empty() {
1866 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1867 } else {
1868 intents_doc_lines.push(format!("- {}", variant_name));
1869 }
1870 }
1871
1872 let intents_doc_str = intents_doc_lines.join("\n");
1873
1874 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
1876 let user_variables: Vec<String> = placeholders
1877 .iter()
1878 .filter_map(|(name, _)| {
1879 if name != "intents_doc" {
1880 Some(name.clone())
1881 } else {
1882 None
1883 }
1884 })
1885 .collect();
1886
1887 let enum_name_str = enum_name.to_string();
1889 let snake_case_name = to_snake_case(&enum_name_str);
1890 let function_name = syn::Ident::new(
1891 &format!("build_{}_prompt", snake_case_name),
1892 proc_macro2::Span::call_site(),
1893 );
1894
1895 let function_params: Vec<proc_macro2::TokenStream> = user_variables
1897 .iter()
1898 .map(|var| {
1899 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1900 quote! { #ident: &str }
1901 })
1902 .collect();
1903
1904 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1906 .iter()
1907 .map(|var| {
1908 let var_str = var.clone();
1909 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1910 quote! {
1911 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1912 }
1913 })
1914 .collect();
1915
1916 let converted_template = prompt_template.clone();
1918
1919 let extractor_name = syn::Ident::new(
1921 &format!("{}Extractor", enum_name),
1922 proc_macro2::Span::call_site(),
1923 );
1924
1925 let filtered_attrs: Vec<_> = input
1927 .attrs
1928 .iter()
1929 .filter(|attr| !attr.path().is_ident("intent"))
1930 .collect();
1931
1932 let vis = &input.vis;
1934 let generics = &input.generics;
1935 let variants = &enum_data.variants;
1936 let enum_output = quote! {
1937 #(#filtered_attrs)*
1938 #vis enum #enum_name #generics {
1939 #variants
1940 }
1941 };
1942
1943 let expanded = quote! {
1945 #enum_output
1947
1948 pub fn #function_name(#(#function_params),*) -> String {
1950 let mut env = minijinja::Environment::new();
1951 env.add_template("prompt", #converted_template)
1952 .expect("Failed to parse intent prompt template");
1953
1954 let tmpl = env.get_template("prompt").unwrap();
1955
1956 let mut __template_context = std::collections::HashMap::new();
1957
1958 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
1960
1961 #(#context_insertions)*
1963
1964 tmpl.render(&__template_context)
1965 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
1966 }
1967
1968 pub struct #extractor_name;
1970
1971 impl #extractor_name {
1972 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
1973 }
1974
1975 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
1976 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
1977 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
1979 }
1980 }
1981 };
1982
1983 TokenStream::from(expanded)
1984}
1985
/// Converts a `CamelCase` identifier to `snake_case` for generated function names.
///
/// An underscore is inserted before an uppercase character only when it is not the first
/// character and the previous character was not uppercase. Consecutive capitals are therefore
/// merged (`"HTTPServer"` becomes `"httpserver"`), and digits count as non-uppercase.
/// Each uppercase character contributes only the first character of its lowercase mapping.
fn to_snake_case(s: &str) -> String {
    let (snake, _) = s.chars().enumerate().fold(
        (String::new(), false),
        |(mut acc, after_upper), (position, ch)| {
            if !ch.is_uppercase() {
                acc.push(ch);
                return (acc, false);
            }
            if position != 0 && !after_upper {
                acc.push('_');
            }
            acc.extend(ch.to_lowercase().take(1));
            (acc, true)
        },
    );
    snake
}
2006
2007#[derive(Debug, Default)]
2009struct ActionAttrs {
2010 tag: Option<String>,
2011}
2012
2013fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2014 let mut result = ActionAttrs::default();
2015
2016 for attr in attrs {
2017 if attr.path().is_ident("action")
2018 && let Ok(meta_list) = attr.meta.require_list()
2019 && let Ok(metas) =
2020 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2021 {
2022 for meta in metas {
2023 if let Meta::NameValue(nv) = meta
2024 && nv.path.is_ident("tag")
2025 && let syn::Expr::Lit(syn::ExprLit {
2026 lit: syn::Lit::Str(lit_str),
2027 ..
2028 }) = nv.value
2029 {
2030 result.tag = Some(lit_str.value());
2031 }
2032 }
2033 }
2034 }
2035
2036 result
2037}
2038
2039#[derive(Debug, Default)]
2041struct FieldActionAttrs {
2042 is_attribute: bool,
2043 is_inner_text: bool,
2044}
2045
2046fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2047 let mut result = FieldActionAttrs::default();
2048
2049 for attr in attrs {
2050 if attr.path().is_ident("action")
2051 && let Ok(meta_list) = attr.meta.require_list()
2052 {
2053 let tokens_str = meta_list.tokens.to_string();
2054 if tokens_str == "attribute" {
2055 result.is_attribute = true;
2056 } else if tokens_str == "inner_text" {
2057 result.is_inner_text = true;
2058 }
2059 }
2060 }
2061
2062 result
2063}
2064
2065fn generate_multi_tag_actions_doc(
2067 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2068) -> String {
2069 let mut doc_lines = Vec::new();
2070
2071 for variant in variants {
2072 let action_attrs = parse_action_attrs(&variant.attrs);
2073
2074 if let Some(tag) = action_attrs.tag {
2075 let variant_docs = extract_doc_comments(&variant.attrs);
2076
2077 match &variant.fields {
2078 syn::Fields::Unit => {
2079 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2081 }
2082 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2083 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2085 }
2086 syn::Fields::Named(fields) => {
2087 let mut attrs_str = Vec::new();
2089 let mut has_inner_text = false;
2090
2091 for field in &fields.named {
2092 let field_name = field.ident.as_ref().unwrap();
2093 let field_attrs = parse_field_action_attrs(&field.attrs);
2094
2095 if field_attrs.is_attribute {
2096 attrs_str.push(format!("{}=\"...\"", field_name));
2097 } else if field_attrs.is_inner_text {
2098 has_inner_text = true;
2099 }
2100 }
2101
2102 let attrs_part = if !attrs_str.is_empty() {
2103 format!(" {}", attrs_str.join(" "))
2104 } else {
2105 String::new()
2106 };
2107
2108 if has_inner_text {
2109 doc_lines.push(format!(
2110 "- `<{}{}>...</{}>`: {}",
2111 tag, attrs_part, tag, variant_docs
2112 ));
2113 } else if !attrs_str.is_empty() {
2114 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2115 } else {
2116 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2117 }
2118
2119 for field in &fields.named {
2121 let field_name = field.ident.as_ref().unwrap();
2122 let field_attrs = parse_field_action_attrs(&field.attrs);
2123 let field_docs = extract_doc_comments(&field.attrs);
2124
2125 if field_attrs.is_attribute {
2126 doc_lines
2127 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2128 } else if field_attrs.is_inner_text {
2129 doc_lines
2130 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2131 }
2132 }
2133 }
2134 _ => {
2135 }
2137 }
2138 }
2139 }
2140
2141 doc_lines.join("\n")
2142}
2143
2144fn generate_tags_regex(
2146 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2147) -> String {
2148 let mut tag_names = Vec::new();
2149
2150 for variant in variants {
2151 let action_attrs = parse_action_attrs(&variant.attrs);
2152 if let Some(tag) = action_attrs.tag {
2153 tag_names.push(tag);
2154 }
2155 }
2156
2157 if tag_names.is_empty() {
2158 return String::new();
2159 }
2160
2161 let tags_pattern = tag_names.join("|");
2162 format!(
2165 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2166 tags_pattern, tags_pattern, tags_pattern
2167 )
2168}
2169
2170fn generate_multi_tag_output(
2172 input: &DeriveInput,
2173 enum_name: &syn::Ident,
2174 enum_data: &syn::DataEnum,
2175 prompt_template: String,
2176 actions_doc: String,
2177) -> TokenStream {
2178 let found_crate =
2179 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2180 let crate_path = match found_crate {
2181 FoundCrate::Itself => {
2182 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2184 quote!(::#ident)
2185 }
2186 FoundCrate::Name(name) => {
2187 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2188 quote!(::#ident)
2189 }
2190 };
2191
2192 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2194 let user_variables: Vec<String> = placeholders
2195 .iter()
2196 .filter_map(|(name, _)| {
2197 if name != "actions_doc" {
2198 Some(name.clone())
2199 } else {
2200 None
2201 }
2202 })
2203 .collect();
2204
2205 let enum_name_str = enum_name.to_string();
2207 let snake_case_name = to_snake_case(&enum_name_str);
2208 let function_name = syn::Ident::new(
2209 &format!("build_{}_prompt", snake_case_name),
2210 proc_macro2::Span::call_site(),
2211 );
2212
2213 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2215 .iter()
2216 .map(|var| {
2217 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2218 quote! { #ident: &str }
2219 })
2220 .collect();
2221
2222 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2224 .iter()
2225 .map(|var| {
2226 let var_str = var.clone();
2227 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2228 quote! {
2229 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2230 }
2231 })
2232 .collect();
2233
2234 let extractor_name = syn::Ident::new(
2236 &format!("{}Extractor", enum_name),
2237 proc_macro2::Span::call_site(),
2238 );
2239
2240 let filtered_attrs: Vec<_> = input
2242 .attrs
2243 .iter()
2244 .filter(|attr| !attr.path().is_ident("intent"))
2245 .collect();
2246
2247 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2249 .variants
2250 .iter()
2251 .map(|variant| {
2252 let variant_name = &variant.ident;
2253 let variant_attrs: Vec<_> = variant
2254 .attrs
2255 .iter()
2256 .filter(|attr| !attr.path().is_ident("action"))
2257 .collect();
2258 let fields = &variant.fields;
2259
2260 let filtered_fields = match fields {
2262 syn::Fields::Named(named_fields) => {
2263 let filtered: Vec<_> = named_fields
2264 .named
2265 .iter()
2266 .map(|field| {
2267 let field_name = &field.ident;
2268 let field_type = &field.ty;
2269 let field_vis = &field.vis;
2270 let filtered_attrs: Vec<_> = field
2271 .attrs
2272 .iter()
2273 .filter(|attr| !attr.path().is_ident("action"))
2274 .collect();
2275 quote! {
2276 #(#filtered_attrs)*
2277 #field_vis #field_name: #field_type
2278 }
2279 })
2280 .collect();
2281 quote! { { #(#filtered,)* } }
2282 }
2283 syn::Fields::Unnamed(unnamed_fields) => {
2284 let types: Vec<_> = unnamed_fields
2285 .unnamed
2286 .iter()
2287 .map(|field| {
2288 let field_type = &field.ty;
2289 quote! { #field_type }
2290 })
2291 .collect();
2292 quote! { (#(#types),*) }
2293 }
2294 syn::Fields::Unit => quote! {},
2295 };
2296
2297 quote! {
2298 #(#variant_attrs)*
2299 #variant_name #filtered_fields
2300 }
2301 })
2302 .collect();
2303
2304 let vis = &input.vis;
2305 let generics = &input.generics;
2306
2307 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2309
2310 let tags_regex = generate_tags_regex(&enum_data.variants);
2312
2313 let expanded = quote! {
2314 #(#filtered_attrs)*
2316 #vis enum #enum_name #generics {
2317 #(#filtered_variants),*
2318 }
2319
2320 pub fn #function_name(#(#function_params),*) -> String {
2322 let mut env = minijinja::Environment::new();
2323 env.add_template("prompt", #prompt_template)
2324 .expect("Failed to parse intent prompt template");
2325
2326 let tmpl = env.get_template("prompt").unwrap();
2327
2328 let mut __template_context = std::collections::HashMap::new();
2329
2330 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2332
2333 #(#context_insertions)*
2335
2336 tmpl.render(&__template_context)
2337 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2338 }
2339
2340 pub struct #extractor_name;
2342
2343 impl #extractor_name {
2344 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2345 use ::quick_xml::events::Event;
2346 use ::quick_xml::Reader;
2347
2348 let mut actions = Vec::new();
2349 let mut reader = Reader::from_str(text);
2350 reader.config_mut().trim_text(true);
2351
2352 let mut buf = Vec::new();
2353
2354 loop {
2355 match reader.read_event_into(&mut buf) {
2356 Ok(Event::Start(e)) => {
2357 let owned_e = e.into_owned();
2358 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2359 let is_empty = false;
2360
2361 #parsing_arms
2362 }
2363 Ok(Event::Empty(e)) => {
2364 let owned_e = e.into_owned();
2365 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2366 let is_empty = true;
2367
2368 #parsing_arms
2369 }
2370 Ok(Event::Eof) => break,
2371 Err(_) => {
2372 break;
2374 }
2375 _ => {}
2376 }
2377 buf.clear();
2378 }
2379
2380 actions.into_iter().next()
2381 }
2382
2383 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2384 use ::quick_xml::events::Event;
2385 use ::quick_xml::Reader;
2386
2387 let mut actions = Vec::new();
2388 let mut reader = Reader::from_str(text);
2389 reader.config_mut().trim_text(true);
2390
2391 let mut buf = Vec::new();
2392
2393 loop {
2394 match reader.read_event_into(&mut buf) {
2395 Ok(Event::Start(e)) => {
2396 let owned_e = e.into_owned();
2397 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2398 let is_empty = false;
2399
2400 #parsing_arms
2401 }
2402 Ok(Event::Empty(e)) => {
2403 let owned_e = e.into_owned();
2404 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2405 let is_empty = true;
2406
2407 #parsing_arms
2408 }
2409 Ok(Event::Eof) => break,
2410 Err(_) => {
2411 break;
2413 }
2414 _ => {}
2415 }
2416 buf.clear();
2417 }
2418
2419 Ok(actions)
2420 }
2421
2422 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2423 where
2424 F: FnMut(#enum_name) -> String,
2425 {
2426 use ::regex::Regex;
2427
2428 let regex_pattern = #tags_regex;
2429 if regex_pattern.is_empty() {
2430 return text.to_string();
2431 }
2432
2433 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2434 panic!("Failed to compile regex for action tags: {}", e);
2435 });
2436
2437 re.replace_all(text, |caps: &::regex::Captures| {
2438 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2439
2440 if let Some(action) = self.parse_single_action(matched) {
2442 transformer(action)
2443 } else {
2444 matched.to_string()
2446 }
2447 }).to_string()
2448 }
2449
2450 pub fn strip_actions(&self, text: &str) -> String {
2451 self.transform_actions(text, |_| String::new())
2452 }
2453 }
2454 };
2455
2456 TokenStream::from(expanded)
2457}
2458
2459fn generate_parsing_arms(
2461 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2462 enum_name: &syn::Ident,
2463) -> proc_macro2::TokenStream {
2464 let mut arms = Vec::new();
2465
2466 for variant in variants {
2467 let variant_name = &variant.ident;
2468 let action_attrs = parse_action_attrs(&variant.attrs);
2469
2470 if let Some(tag) = action_attrs.tag {
2471 match &variant.fields {
2472 syn::Fields::Unit => {
2473 arms.push(quote! {
2475 if &tag_name == #tag {
2476 actions.push(#enum_name::#variant_name);
2477 }
2478 });
2479 }
2480 syn::Fields::Unnamed(_fields) => {
2481 arms.push(quote! {
2483 if &tag_name == #tag && !is_empty {
2484 match reader.read_text(owned_e.name()) {
2486 Ok(text) => {
2487 actions.push(#enum_name::#variant_name(text.to_string()));
2488 }
2489 Err(_) => {
2490 actions.push(#enum_name::#variant_name(String::new()));
2492 }
2493 }
2494 }
2495 });
2496 }
2497 syn::Fields::Named(fields) => {
2498 let mut field_names = Vec::new();
2500 let mut has_inner_text_field = None;
2501
2502 for field in &fields.named {
2503 let field_name = field.ident.as_ref().unwrap();
2504 let field_attrs = parse_field_action_attrs(&field.attrs);
2505
2506 if field_attrs.is_attribute {
2507 field_names.push(field_name.clone());
2508 } else if field_attrs.is_inner_text {
2509 has_inner_text_field = Some(field_name.clone());
2510 }
2511 }
2512
2513 if let Some(inner_text_field) = has_inner_text_field {
2514 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2517 quote! {
2518 let mut #field_name = String::new();
2519 for attr in owned_e.attributes() {
2520 if let Ok(attr) = attr {
2521 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2522 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2523 break;
2524 }
2525 }
2526 }
2527 }
2528 }).collect();
2529
2530 arms.push(quote! {
2531 if &tag_name == #tag {
2532 #(#attr_extractions)*
2533
2534 if is_empty {
2536 let #inner_text_field = String::new();
2537 actions.push(#enum_name::#variant_name {
2538 #(#field_names,)*
2539 #inner_text_field,
2540 });
2541 } else {
2542 match reader.read_text(owned_e.name()) {
2544 Ok(text) => {
2545 let #inner_text_field = text.to_string();
2546 actions.push(#enum_name::#variant_name {
2547 #(#field_names,)*
2548 #inner_text_field,
2549 });
2550 }
2551 Err(_) => {
2552 let #inner_text_field = String::new();
2554 actions.push(#enum_name::#variant_name {
2555 #(#field_names,)*
2556 #inner_text_field,
2557 });
2558 }
2559 }
2560 }
2561 }
2562 });
2563 } else {
2564 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2566 quote! {
2567 let mut #field_name = String::new();
2568 for attr in owned_e.attributes() {
2569 if let Ok(attr) = attr {
2570 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2571 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2572 break;
2573 }
2574 }
2575 }
2576 }
2577 }).collect();
2578
2579 arms.push(quote! {
2580 if &tag_name == #tag {
2581 #(#attr_extractions)*
2582 actions.push(#enum_name::#variant_name {
2583 #(#field_names),*
2584 });
2585 }
2586 });
2587 }
2588 }
2589 }
2590 }
2591 }
2592
2593 quote! {
2594 #(#arms)*
2595 }
2596}
2597
2598#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2600pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2601 let input = parse_macro_input!(input as DeriveInput);
2602
2603 let found_crate =
2604 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2605 let crate_path = match found_crate {
2606 FoundCrate::Itself => {
2607 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2609 quote!(::#ident)
2610 }
2611 FoundCrate::Name(name) => {
2612 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2613 quote!(::#ident)
2614 }
2615 };
2616
2617 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2619
2620 let struct_name = &input.ident;
2621 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2622
2623 let placeholders = parse_template_placeholders_with_mode(&template);
2625
2626 let mut converted_template = template.clone();
2628 let mut context_fields = Vec::new();
2629
2630 let fields = match &input.data {
2632 Data::Struct(data_struct) => match &data_struct.fields {
2633 syn::Fields::Named(fields) => &fields.named,
2634 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2635 },
2636 _ => panic!("ToPromptFor is only supported for structs"),
2637 };
2638
2639 let has_mode_support = input.attrs.iter().any(|attr| {
2641 if attr.path().is_ident("prompt")
2642 && let Ok(metas) =
2643 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2644 {
2645 for meta in metas {
2646 if let Meta::NameValue(nv) = meta
2647 && nv.path.is_ident("mode")
2648 {
2649 return true;
2650 }
2651 }
2652 }
2653 false
2654 });
2655
2656 for (placeholder_name, mode_opt) in &placeholders {
2658 if placeholder_name == "self" {
2659 if let Some(specific_mode) = mode_opt {
2660 let unique_key = format!("self__{}", specific_mode);
2662
2663 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2665 let replacement = format!("{{{{ {} }}}}", unique_key);
2666 converted_template = converted_template.replace(&pattern, &replacement);
2667
2668 context_fields.push(quote! {
2670 context.insert(
2671 #unique_key.to_string(),
2672 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2673 );
2674 });
2675 } else {
2676 if has_mode_support {
2679 context_fields.push(quote! {
2681 context.insert(
2682 "self".to_string(),
2683 minijinja::Value::from(self.to_prompt_with_mode(mode))
2684 );
2685 });
2686 } else {
2687 context_fields.push(quote! {
2689 context.insert(
2690 "self".to_string(),
2691 minijinja::Value::from(self.to_prompt())
2692 );
2693 });
2694 }
2695 }
2696 } else {
2697 let field_exists = fields.iter().any(|f| {
2700 f.ident
2701 .as_ref()
2702 .is_some_and(|ident| ident == placeholder_name)
2703 });
2704
2705 if field_exists {
2706 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2707
2708 context_fields.push(quote! {
2712 context.insert(
2713 #placeholder_name.to_string(),
2714 minijinja::Value::from_serialize(&self.#field_ident)
2715 );
2716 });
2717 }
2718 }
2720 }
2721
2722 let expanded = quote! {
2723 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2724 where
2725 #target_type: serde::Serialize,
2726 {
2727 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2728 let mut env = minijinja::Environment::new();
2730 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2731 panic!("Failed to parse template: {}", e)
2732 });
2733
2734 let tmpl = env.get_template("prompt").unwrap();
2735
2736 let mut context = std::collections::HashMap::new();
2738 context.insert(
2740 "self".to_string(),
2741 minijinja::Value::from_serialize(self)
2742 );
2743 context.insert(
2745 "target".to_string(),
2746 minijinja::Value::from_serialize(target)
2747 );
2748 #(#context_fields)*
2749
2750 tmpl.render(context).unwrap_or_else(|e| {
2752 format!("Failed to render prompt: {}", e)
2753 })
2754 }
2755 }
2756 };
2757
2758 TokenStream::from(expanded)
2759}
2760
2761struct AgentAttrs {
2767 expertise: Option<String>,
2768 output: Option<syn::Type>,
2769 backend: Option<String>,
2770 model: Option<String>,
2771 inner: Option<String>,
2772 default_inner: Option<String>,
2773 max_retries: Option<u32>,
2774 profile: Option<String>,
2775}
2776
2777impl Parse for AgentAttrs {
2778 fn parse(input: ParseStream) -> syn::Result<Self> {
2779 let mut expertise = None;
2780 let mut output = None;
2781 let mut backend = None;
2782 let mut model = None;
2783 let mut inner = None;
2784 let mut default_inner = None;
2785 let mut max_retries = None;
2786 let mut profile = None;
2787
2788 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2789
2790 for meta in pairs {
2791 match meta {
2792 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2793 if let syn::Expr::Lit(syn::ExprLit {
2794 lit: syn::Lit::Str(lit_str),
2795 ..
2796 }) = &nv.value
2797 {
2798 expertise = Some(lit_str.value());
2799 }
2800 }
2801 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2802 if let syn::Expr::Lit(syn::ExprLit {
2803 lit: syn::Lit::Str(lit_str),
2804 ..
2805 }) = &nv.value
2806 {
2807 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2808 output = Some(ty);
2809 }
2810 }
2811 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2812 if let syn::Expr::Lit(syn::ExprLit {
2813 lit: syn::Lit::Str(lit_str),
2814 ..
2815 }) = &nv.value
2816 {
2817 backend = Some(lit_str.value());
2818 }
2819 }
2820 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2821 if let syn::Expr::Lit(syn::ExprLit {
2822 lit: syn::Lit::Str(lit_str),
2823 ..
2824 }) = &nv.value
2825 {
2826 model = Some(lit_str.value());
2827 }
2828 }
2829 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
2830 if let syn::Expr::Lit(syn::ExprLit {
2831 lit: syn::Lit::Str(lit_str),
2832 ..
2833 }) = &nv.value
2834 {
2835 inner = Some(lit_str.value());
2836 }
2837 }
2838 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
2839 if let syn::Expr::Lit(syn::ExprLit {
2840 lit: syn::Lit::Str(lit_str),
2841 ..
2842 }) = &nv.value
2843 {
2844 default_inner = Some(lit_str.value());
2845 }
2846 }
2847 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
2848 if let syn::Expr::Lit(syn::ExprLit {
2849 lit: syn::Lit::Int(lit_int),
2850 ..
2851 }) = &nv.value
2852 {
2853 max_retries = Some(lit_int.base10_parse()?);
2854 }
2855 }
2856 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
2857 if let syn::Expr::Lit(syn::ExprLit {
2858 lit: syn::Lit::Str(lit_str),
2859 ..
2860 }) = &nv.value
2861 {
2862 profile = Some(lit_str.value());
2863 }
2864 }
2865 _ => {}
2866 }
2867 }
2868
2869 Ok(AgentAttrs {
2870 expertise,
2871 output,
2872 backend,
2873 model,
2874 inner,
2875 default_inner,
2876 max_retries,
2877 profile,
2878 })
2879 }
2880}
2881
2882fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
2884 for attr in attrs {
2885 if attr.path().is_ident("agent") {
2886 return attr.parse_args::<AgentAttrs>();
2887 }
2888 }
2889
2890 Ok(AgentAttrs {
2891 expertise: None,
2892 output: None,
2893 backend: None,
2894 model: None,
2895 inner: None,
2896 default_inner: None,
2897 max_retries: None,
2898 profile: None,
2899 })
2900}
2901
2902fn generate_backend_constructors(
2904 struct_name: &syn::Ident,
2905 backend: &str,
2906 _model: Option<&str>,
2907 _profile: Option<&str>,
2908 crate_path: &proc_macro2::TokenStream,
2909) -> proc_macro2::TokenStream {
2910 match backend {
2911 "claude" => {
2912 quote! {
2913 impl #struct_name {
2914 pub fn with_claude() -> Self {
2916 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
2917 }
2918
2919 pub fn with_claude_model(model: &str) -> Self {
2921 Self::new(
2922 #crate_path::agent::impls::ClaudeCodeAgent::new()
2923 .with_model_str(model)
2924 )
2925 }
2926 }
2927 }
2928 }
2929 "gemini" => {
2930 quote! {
2931 impl #struct_name {
2932 pub fn with_gemini() -> Self {
2934 Self::new(#crate_path::agent::impls::GeminiAgent::new())
2935 }
2936
2937 pub fn with_gemini_model(model: &str) -> Self {
2939 Self::new(
2940 #crate_path::agent::impls::GeminiAgent::new()
2941 .with_model_str(model)
2942 )
2943 }
2944 }
2945 }
2946 }
2947 _ => quote! {},
2948 }
2949}
2950
2951fn generate_default_impl(
2953 struct_name: &syn::Ident,
2954 backend: &str,
2955 model: Option<&str>,
2956 profile: Option<&str>,
2957 crate_path: &proc_macro2::TokenStream,
2958) -> proc_macro2::TokenStream {
2959 let profile_expr = if let Some(profile_str) = profile {
2961 match profile_str.to_lowercase().as_str() {
2962 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
2963 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
2964 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
2965 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
2967 } else {
2968 quote! { #crate_path::agent::ExecutionProfile::default() }
2969 };
2970
2971 let agent_init = match backend {
2972 "gemini" => {
2973 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
2974
2975 if let Some(model_str) = model {
2976 builder = quote! { #builder.with_model_str(#model_str) };
2977 }
2978
2979 builder = quote! { #builder.with_execution_profile(#profile_expr) };
2980 builder
2981 }
2982 _ => {
2983 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
2985
2986 if let Some(model_str) = model {
2987 builder = quote! { #builder.with_model_str(#model_str) };
2988 }
2989
2990 builder = quote! { #builder.with_execution_profile(#profile_expr) };
2991 builder
2992 }
2993 };
2994
2995 quote! {
2996 impl Default for #struct_name {
2997 fn default() -> Self {
2998 Self::new(#agent_init)
2999 }
3000 }
3001 }
3002}
3003
3004#[proc_macro_derive(Agent, attributes(agent))]
3013pub fn derive_agent(input: TokenStream) -> TokenStream {
3014 let input = parse_macro_input!(input as DeriveInput);
3015 let struct_name = &input.ident;
3016
3017 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3019 Ok(attrs) => attrs,
3020 Err(e) => return e.to_compile_error().into(),
3021 };
3022
3023 let expertise = agent_attrs
3024 .expertise
3025 .unwrap_or_else(|| String::from("general AI assistant"));
3026 let output_type = agent_attrs
3027 .output
3028 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3029 let backend = agent_attrs
3030 .backend
3031 .unwrap_or_else(|| String::from("claude"));
3032 let model = agent_attrs.model;
3033 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3038 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3039 let crate_path = match found_crate {
3040 FoundCrate::Itself => {
3041 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3043 quote!(::#ident)
3044 }
3045 FoundCrate::Name(name) => {
3046 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3047 quote!(::#ident)
3048 }
3049 };
3050
3051 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3052
3053 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3055 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3056
3057 let enhanced_expertise = if is_string_output {
3059 quote! { #expertise }
3061 } else {
3062 let type_name = quote!(#output_type).to_string();
3064 quote! {
3065 {
3066 use std::sync::OnceLock;
3067 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3068
3069 EXPERTISE_CACHE.get_or_init(|| {
3070 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3072
3073 if schema.is_empty() {
3074 format!(
3076 concat!(
3077 #expertise,
3078 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3079 "Do not include any text outside the JSON object."
3080 ),
3081 #type_name
3082 )
3083 } else {
3084 format!(
3086 concat!(
3087 #expertise,
3088 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3089 ),
3090 schema
3091 )
3092 }
3093 }).as_str()
3094 }
3095 }
3096 };
3097
3098 let agent_init = match backend.as_str() {
3100 "gemini" => {
3101 if let Some(model_str) = model {
3102 quote! {
3103 use #crate_path::agent::impls::GeminiAgent;
3104 let agent = GeminiAgent::new().with_model_str(#model_str);
3105 }
3106 } else {
3107 quote! {
3108 use #crate_path::agent::impls::GeminiAgent;
3109 let agent = GeminiAgent::new();
3110 }
3111 }
3112 }
3113 "claude" => {
3114 if let Some(model_str) = model {
3115 quote! {
3116 use #crate_path::agent::impls::ClaudeCodeAgent;
3117 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3118 }
3119 } else {
3120 quote! {
3121 use #crate_path::agent::impls::ClaudeCodeAgent;
3122 let agent = ClaudeCodeAgent::new();
3123 }
3124 }
3125 }
3126 _ => {
3127 if let Some(model_str) = model {
3129 quote! {
3130 use #crate_path::agent::impls::ClaudeCodeAgent;
3131 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3132 }
3133 } else {
3134 quote! {
3135 use #crate_path::agent::impls::ClaudeCodeAgent;
3136 let agent = ClaudeCodeAgent::new();
3137 }
3138 }
3139 }
3140 };
3141
3142 let expanded = quote! {
3143 #[async_trait::async_trait]
3144 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3145 type Output = #output_type;
3146
3147 fn expertise(&self) -> &str {
3148 #enhanced_expertise
3149 }
3150
3151 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3152 #agent_init
3154
3155 let max_retries: u32 = #max_retries;
3157 let mut attempts = 0u32;
3158
3159 loop {
3160 attempts += 1;
3161
3162 let result = async {
3164 let response = agent.execute(intent.clone()).await?;
3165
3166 let json_str = #crate_path::extract_json(&response)
3168 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3169
3170 serde_json::from_str::<Self::Output>(&json_str)
3172 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3173 }.await;
3174
3175 match result {
3176 Ok(output) => return Ok(output),
3177 Err(e) if e.is_retryable() && attempts < max_retries => {
3178 log::warn!(
3180 "Agent execution failed (attempt {}/{}): {}. Retrying...",
3181 attempts,
3182 max_retries,
3183 e
3184 );
3185
3186 let delay_ms = 100 * attempts as u64;
3188 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
3189
3190 continue;
3192 }
3193 Err(e) => {
3194 if attempts > 1 {
3195 log::error!(
3196 "Agent execution failed after {} attempts: {}",
3197 attempts,
3198 e
3199 );
3200 }
3201 return Err(e);
3202 }
3203 }
3204 }
3205 }
3206
3207 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3208 #agent_init
3210 agent.is_available().await
3211 }
3212 }
3213 };
3214
3215 TokenStream::from(expanded)
3216}
3217
3218#[proc_macro_attribute]
3233pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3234 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3236 Ok(attrs) => attrs,
3237 Err(e) => return e.to_compile_error().into(),
3238 };
3239
3240 let input = parse_macro_input!(item as DeriveInput);
3242 let struct_name = &input.ident;
3243 let vis = &input.vis;
3244
3245 let expertise = agent_attrs
3246 .expertise
3247 .unwrap_or_else(|| String::from("general AI assistant"));
3248 let output_type = agent_attrs
3249 .output
3250 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3251 let backend = agent_attrs
3252 .backend
3253 .unwrap_or_else(|| String::from("claude"));
3254 let model = agent_attrs.model;
3255 let profile = agent_attrs.profile;
3256
3257 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3259 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3260
3261 let found_crate =
3263 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3264 let crate_path = match found_crate {
3265 FoundCrate::Itself => {
3266 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3267 quote!(::#ident)
3268 }
3269 FoundCrate::Name(name) => {
3270 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3271 quote!(::#ident)
3272 }
3273 };
3274
3275 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3277 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3278
3279 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3281 let type_path: syn::Type =
3283 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3284 quote! { #type_path }
3285 } else {
3286 match backend.as_str() {
3288 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3289 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3290 }
3291 };
3292
3293 let struct_def = quote! {
3295 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3296 inner: #inner_generic_ident,
3297 }
3298 };
3299
3300 let constructors = quote! {
3302 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3303 pub fn new(inner: #inner_generic_ident) -> Self {
3305 Self { inner }
3306 }
3307 }
3308 };
3309
3310 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3312 let default_impl = quote! {
3314 impl Default for #struct_name {
3315 fn default() -> Self {
3316 Self {
3317 inner: <#default_agent_type as Default>::default(),
3318 }
3319 }
3320 }
3321 };
3322 (quote! {}, default_impl)
3323 } else {
3324 let backend_constructors = generate_backend_constructors(
3326 struct_name,
3327 &backend,
3328 model.as_deref(),
3329 profile.as_deref(),
3330 &crate_path,
3331 );
3332 let default_impl = generate_default_impl(
3333 struct_name,
3334 &backend,
3335 model.as_deref(),
3336 profile.as_deref(),
3337 &crate_path,
3338 );
3339 (backend_constructors, default_impl)
3340 };
3341
3342 let enhanced_expertise = if is_string_output {
3344 quote! { #expertise }
3346 } else {
3347 let type_name = quote!(#output_type).to_string();
3349 quote! {
3350 {
3351 use std::sync::OnceLock;
3352 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3353
3354 EXPERTISE_CACHE.get_or_init(|| {
3355 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3357
3358 if schema.is_empty() {
3359 format!(
3361 concat!(
3362 #expertise,
3363 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3364 "Do not include any text outside the JSON object."
3365 ),
3366 #type_name
3367 )
3368 } else {
3369 format!(
3371 concat!(
3372 #expertise,
3373 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3374 ),
3375 schema
3376 )
3377 }
3378 }).as_str()
3379 }
3380 }
3381 };
3382
3383 let agent_impl = quote! {
3385 #[async_trait::async_trait]
3386 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3387 where
3388 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3389 {
3390 type Output = #output_type;
3391
3392 fn expertise(&self) -> &str {
3393 #enhanced_expertise
3394 }
3395
3396 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3397 let enhanced_payload = intent.prepend_text(self.expertise());
3399
3400 let response = self.inner.execute(enhanced_payload).await?;
3402
3403 let json_str = #crate_path::extract_json(&response)
3405 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3406
3407 serde_json::from_str(&json_str)
3409 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3410 }
3411
3412 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3413 self.inner.is_available().await
3414 }
3415 }
3416 };
3417
3418 let expanded = quote! {
3419 #struct_def
3420 #constructors
3421 #backend_constructors
3422 #default_impl
3423 #agent_impl
3424 };
3425
3426 TokenStream::from(expanded)
3427}