1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if attrs.skip {
76 continue;
77 }
78
79 if let Some(example) = attrs.example {
81 field_values.push(quote! {
83 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
84 });
85 } else if has_default {
86 field_values.push(quote! {
88 let default_value = serde_json::to_value(&default_instance.#field_name)
89 .unwrap_or(serde_json::Value::Null);
90 json_obj.insert(#field_name_str.to_string(), default_value);
91 });
92 } else {
93 field_values.push(quote! {
95 let value = serde_json::to_value(&self.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), value);
98 });
99 }
100 }
101
102 if has_default {
103 quote! {
104 {
105 let default_instance = Self::default();
106 let mut json_obj = serde_json::Map::new();
107 #(#field_values)*
108 let json_value = serde_json::Value::Object(json_obj);
109 let json_str = serde_json::to_string_pretty(&json_value)
110 .unwrap_or_else(|_| "{}".to_string());
111 vec![#crate_path::prompt::PromptPart::Text(json_str)]
112 }
113 }
114 } else {
115 quote! {
116 {
117 let mut json_obj = serde_json::Map::new();
118 #(#field_values)*
119 let json_value = serde_json::Value::Object(json_obj);
120 let json_str = serde_json::to_string_pretty(&json_value)
121 .unwrap_or_else(|_| "{}".to_string());
122 vec![#crate_path::prompt::PromptPart::Text(json_str)]
123 }
124 }
125 }
126}
127
128fn generate_schema_only_parts(
130 struct_name: &str,
131 struct_docs: &str,
132 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
133 crate_path: &proc_macro2::TokenStream,
134) -> proc_macro2::TokenStream {
135 let mut field_schema_parts = vec![];
136
137 for (i, field) in fields.iter().enumerate() {
139 let field_name = field.ident.as_ref().unwrap();
140 let field_name_str = field_name.to_string();
141 let attrs = parse_field_prompt_attrs(&field.attrs);
142
143 if attrs.skip {
145 continue;
146 }
147
148 let field_docs = extract_doc_comments(&field.attrs);
150
151 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
153
154 let remaining_fields = fields
156 .iter()
157 .skip(i + 1)
158 .filter(|f| {
159 let attrs = parse_field_prompt_attrs(&f.attrs);
160 !attrs.skip
161 })
162 .count();
163 let comma = if remaining_fields > 0 { "," } else { "" };
164
165 if is_vec {
166 let comment = if !field_docs.is_empty() {
168 format!(", // {}", field_docs)
169 } else {
170 String::new()
171 };
172
173 field_schema_parts.push(quote! {
174 {
175 let inner_schema = <#inner_type as #crate_path::prompt::ToPrompt>::prompt_schema();
176 if inner_schema.is_empty() {
177 format!(" \"{}\": \"{}[]\"{}{}", #field_name_str, stringify!(#inner_type).to_lowercase(), #comment, #comma)
179 } else {
180 let inner_lines: Vec<&str> = inner_schema.lines()
182 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
183 .take_while(|line| line.trim() != "}")
184 .collect();
185 let inner_content = inner_lines.join("\n");
186 format!(" \"{}\": [\n {{\n{}\n }}\n ]{}{}",
187 #field_name_str,
188 inner_content.lines()
189 .map(|line| format!(" {}", line))
190 .collect::<Vec<_>>()
191 .join("\n"),
192 #comment,
193 #comma
194 )
195 }
196 }
197 });
198 } else {
199 let type_str = format_type_for_schema(&field.ty);
201 let comment = if !field_docs.is_empty() {
202 format!(", // {}", field_docs)
203 } else {
204 String::new()
205 };
206
207 field_schema_parts.push(quote! {
208 format!(" \"{}\": \"{}\"{}{}", #field_name_str, #type_str, #comment, #comma)
209 });
210 }
211 }
212
213 let header = if !struct_docs.is_empty() {
215 format!("### Schema for `{}`\n{}", struct_name, struct_docs)
216 } else {
217 format!("### Schema for `{}`", struct_name)
218 };
219
220 quote! {
221 {
222 let mut lines = vec![#header.to_string(), "{".to_string()];
223 #(lines.push(#field_schema_parts);)*
224 lines.push("}".to_string());
225 vec![#crate_path::prompt::PromptPart::Text(lines.join("\n"))]
226 }
227 }
228}
229
230fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
232 if let syn::Type::Path(type_path) = ty
233 && let Some(last_segment) = type_path.path.segments.last()
234 && last_segment.ident == "Vec"
235 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
236 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
237 {
238 return (true, Some(inner_type));
239 }
240 (false, None)
241}
242
243fn format_type_for_schema(ty: &syn::Type) -> String {
245 match ty {
247 syn::Type::Path(type_path) => {
248 let path = &type_path.path;
249 if let Some(last_segment) = path.segments.last() {
250 let type_name = last_segment.ident.to_string();
251
252 if type_name == "Option"
254 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
255 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
256 {
257 return format!("{} | null", format_type_for_schema(inner_type));
258 }
259
260 match type_name.as_str() {
262 "String" | "str" => "string".to_string(),
263 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
264 | "u64" | "u128" | "usize" => "number".to_string(),
265 "f32" | "f64" => "number".to_string(),
266 "bool" => "boolean".to_string(),
267 "Vec" => {
268 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
269 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
270 {
271 return format!("{}[]", format_type_for_schema(inner_type));
272 }
273 "array".to_string()
274 }
275 _ => type_name.to_lowercase(),
276 }
277 } else {
278 "unknown".to_string()
279 }
280 }
281 _ => "unknown".to_string(),
282 }
283}
284
285enum PromptAttribute {
287 Skip,
288 Description(String),
289 None,
290}
291
292fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
294 for attr in attrs {
295 if attr.path().is_ident("prompt") {
296 if let Ok(meta_list) = attr.meta.require_list() {
298 let tokens = &meta_list.tokens;
299 let tokens_str = tokens.to_string();
300 if tokens_str == "skip" {
301 return PromptAttribute::Skip;
302 }
303 }
304
305 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
307 return PromptAttribute::Description(lit_str.value());
308 }
309 }
310 }
311 PromptAttribute::None
312}
313
314#[derive(Debug, Default)]
316struct FieldPromptAttrs {
317 skip: bool,
318 rename: Option<String>,
319 format_with: Option<String>,
320 image: bool,
321 example: Option<String>,
322}
323
324fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
326 let mut result = FieldPromptAttrs::default();
327
328 for attr in attrs {
329 if attr.path().is_ident("prompt") {
330 if let Ok(meta_list) = attr.meta.require_list() {
332 if let Ok(metas) =
334 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
335 {
336 for meta in metas {
337 match meta {
338 Meta::Path(path) if path.is_ident("skip") => {
339 result.skip = true;
340 }
341 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
342 if let syn::Expr::Lit(syn::ExprLit {
343 lit: syn::Lit::Str(lit_str),
344 ..
345 }) = nv.value
346 {
347 result.rename = Some(lit_str.value());
348 }
349 }
350 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
351 if let syn::Expr::Lit(syn::ExprLit {
352 lit: syn::Lit::Str(lit_str),
353 ..
354 }) = nv.value
355 {
356 result.format_with = Some(lit_str.value());
357 }
358 }
359 Meta::Path(path) if path.is_ident("image") => {
360 result.image = true;
361 }
362 Meta::NameValue(nv) if nv.path.is_ident("example") => {
363 if let syn::Expr::Lit(syn::ExprLit {
364 lit: syn::Lit::Str(lit_str),
365 ..
366 }) = nv.value
367 {
368 result.example = Some(lit_str.value());
369 }
370 }
371 _ => {}
372 }
373 }
374 } else if meta_list.tokens.to_string() == "skip" {
375 result.skip = true;
377 } else if meta_list.tokens.to_string() == "image" {
378 result.image = true;
380 }
381 }
382 }
383 }
384
385 result
386}
387
388#[proc_macro_derive(ToPrompt, attributes(prompt))]
431pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
432 let input = parse_macro_input!(input as DeriveInput);
433
434 let found_crate =
435 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
436 let crate_path = match found_crate {
437 FoundCrate::Itself => {
438 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
440 quote!(::#ident)
441 }
442 FoundCrate::Name(name) => {
443 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
444 quote!(::#ident)
445 }
446 };
447
448 match &input.data {
450 Data::Enum(data_enum) => {
451 let enum_name = &input.ident;
453 let enum_docs = extract_doc_comments(&input.attrs);
454
455 let mut prompt_lines = Vec::new();
456
457 if !enum_docs.is_empty() {
459 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
460 } else {
461 prompt_lines.push(format!("{}:", enum_name));
462 }
463 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
465
466 for variant in &data_enum.variants {
468 let variant_name = &variant.ident;
469
470 match parse_prompt_attribute(&variant.attrs) {
472 PromptAttribute::Skip => {
473 continue;
475 }
476 PromptAttribute::Description(desc) => {
477 prompt_lines.push(format!("- {}: {}", variant_name, desc));
479 }
480 PromptAttribute::None => {
481 let variant_docs = extract_doc_comments(&variant.attrs);
483 if !variant_docs.is_empty() {
484 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
485 } else {
486 prompt_lines.push(format!("- {}", variant_name));
487 }
488 }
489 }
490 }
491
492 let prompt_string = prompt_lines.join("\n");
493 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
494
495 let mut match_arms = Vec::new();
497 for variant in &data_enum.variants {
498 let variant_name = &variant.ident;
499
500 match parse_prompt_attribute(&variant.attrs) {
502 PromptAttribute::Skip => {
503 match_arms.push(quote! {
505 Self::#variant_name => stringify!(#variant_name).to_string()
506 });
507 }
508 PromptAttribute::Description(desc) => {
509 match_arms.push(quote! {
511 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
512 });
513 }
514 PromptAttribute::None => {
515 let variant_docs = extract_doc_comments(&variant.attrs);
517 if !variant_docs.is_empty() {
518 match_arms.push(quote! {
519 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
520 });
521 } else {
522 match_arms.push(quote! {
523 Self::#variant_name => stringify!(#variant_name).to_string()
524 });
525 }
526 }
527 }
528 }
529
530 let to_prompt_impl = if match_arms.is_empty() {
531 quote! {
533 fn to_prompt(&self) -> String {
534 match *self {}
535 }
536 }
537 } else {
538 quote! {
539 fn to_prompt(&self) -> String {
540 match self {
541 #(#match_arms),*
542 }
543 }
544 }
545 };
546
547 let expanded = quote! {
548 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
549 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
550 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
551 }
552
553 #to_prompt_impl
554
555 fn prompt_schema() -> String {
556 #prompt_string.to_string()
557 }
558 }
559 };
560
561 TokenStream::from(expanded)
562 }
563 Data::Struct(data_struct) => {
564 let mut template_attr = None;
566 let mut template_file_attr = None;
567 let mut mode_attr = None;
568 let mut validate_attr = false;
569
570 for attr in &input.attrs {
571 if attr.path().is_ident("prompt") {
572 if let Ok(metas) =
574 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
575 {
576 for meta in metas {
577 match meta {
578 Meta::NameValue(nv) if nv.path.is_ident("template") => {
579 if let syn::Expr::Lit(expr_lit) = nv.value
580 && let syn::Lit::Str(lit_str) = expr_lit.lit
581 {
582 template_attr = Some(lit_str.value());
583 }
584 }
585 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
586 if let syn::Expr::Lit(expr_lit) = nv.value
587 && let syn::Lit::Str(lit_str) = expr_lit.lit
588 {
589 template_file_attr = Some(lit_str.value());
590 }
591 }
592 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
593 if let syn::Expr::Lit(expr_lit) = nv.value
594 && let syn::Lit::Str(lit_str) = expr_lit.lit
595 {
596 mode_attr = Some(lit_str.value());
597 }
598 }
599 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
600 if let syn::Expr::Lit(expr_lit) = nv.value
601 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
602 {
603 validate_attr = lit_bool.value();
604 }
605 }
606 _ => {}
607 }
608 }
609 }
610 }
611 }
612
613 if template_attr.is_some() && template_file_attr.is_some() {
615 return syn::Error::new(
616 input.ident.span(),
617 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
618 ).to_compile_error().into();
619 }
620
621 let template_str = if let Some(file_path) = template_file_attr {
623 let mut full_path = None;
627
628 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
630 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
632
633 if !is_trybuild {
634 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
636 if candidate.exists() {
637 full_path = Some(candidate);
638 }
639 } else {
640 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
646 let workspace_root = &manifest_dir[..target_pos];
647 let original_macros_dir = std::path::Path::new(workspace_root)
649 .join("crates")
650 .join("llm-toolkit-macros");
651
652 let candidate = original_macros_dir.join(&file_path);
653 if candidate.exists() {
654 full_path = Some(candidate);
655 }
656 }
657 }
658 }
659
660 if full_path.is_none() {
662 let candidate = std::path::Path::new(&file_path).to_path_buf();
663 if candidate.exists() {
664 full_path = Some(candidate);
665 }
666 }
667
668 if full_path.is_none()
671 && let Ok(current_dir) = std::env::current_dir()
672 {
673 let mut search_dir = current_dir.as_path();
674 for _ in 0..10 {
676 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
678 if macros_dir.exists() {
679 let candidate = macros_dir.join(&file_path);
680 if candidate.exists() {
681 full_path = Some(candidate);
682 break;
683 }
684 }
685 let candidate = search_dir.join(&file_path);
687 if candidate.exists() {
688 full_path = Some(candidate);
689 break;
690 }
691 if let Some(parent) = search_dir.parent() {
692 search_dir = parent;
693 } else {
694 break;
695 }
696 }
697 }
698
699 if full_path.is_none() {
701 let mut error_msg = format!(
703 "Template file '{}' not found at compile time.\n\nSearched in:",
704 file_path
705 );
706
707 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
708 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
709 error_msg.push_str(&format!("\n - {}", candidate.display()));
710 }
711
712 if let Ok(current_dir) = std::env::current_dir() {
713 let candidate = current_dir.join(&file_path);
714 error_msg.push_str(&format!("\n - {}", candidate.display()));
715 }
716
717 error_msg.push_str("\n\nPlease ensure:");
718 error_msg.push_str("\n 1. The template file exists");
719 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
720 error_msg.push_str("\n 3. There are no typos in the path");
721
722 return syn::Error::new(input.ident.span(), error_msg)
723 .to_compile_error()
724 .into();
725 }
726
727 let final_path = full_path.unwrap();
728
729 match std::fs::read_to_string(&final_path) {
731 Ok(content) => Some(content),
732 Err(e) => {
733 return syn::Error::new(
734 input.ident.span(),
735 format!(
736 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
737 file_path,
738 e,
739 final_path.display()
740 ),
741 )
742 .to_compile_error()
743 .into();
744 }
745 }
746 } else {
747 template_attr
748 };
749
750 if validate_attr && let Some(template) = &template_str {
752 let mut env = minijinja::Environment::new();
754 if let Err(e) = env.add_template("validation", template) {
755 let warning_msg =
757 format!("Template validation warning: Invalid Jinja syntax - {}", e);
758 let warning_ident = syn::Ident::new(
759 "TEMPLATE_VALIDATION_WARNING",
760 proc_macro2::Span::call_site(),
761 );
762 let _warning_tokens = quote! {
763 #[deprecated(note = #warning_msg)]
764 const #warning_ident: () = ();
765 let _ = #warning_ident;
766 };
767 eprintln!("cargo:warning={}", warning_msg);
769 }
770
771 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
773 &fields.named
774 } else {
775 panic!("Template validation is only supported for structs with named fields.");
776 };
777
778 let field_names: std::collections::HashSet<String> = fields
779 .iter()
780 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
781 .collect();
782
783 let placeholders = parse_template_placeholders_with_mode(template);
785
786 for (placeholder_name, _mode) in &placeholders {
787 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
788 let warning_msg = format!(
789 "Template validation warning: Variable '{}' used in template but not found in struct fields",
790 placeholder_name
791 );
792 eprintln!("cargo:warning={}", warning_msg);
793 }
794 }
795 }
796
797 let name = input.ident;
798 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
799
800 let struct_docs = extract_doc_comments(&input.attrs);
802
803 let is_mode_based =
805 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
806
807 let expanded = if is_mode_based || mode_attr.is_some() {
808 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
810 &fields.named
811 } else {
812 panic!(
813 "Mode-based prompt generation is only supported for structs with named fields."
814 );
815 };
816
817 let struct_name_str = name.to_string();
818
819 let has_default = input.attrs.iter().any(|attr| {
821 if attr.path().is_ident("derive")
822 && let Ok(meta_list) = attr.meta.require_list()
823 {
824 let tokens_str = meta_list.tokens.to_string();
825 tokens_str.contains("Default")
826 } else {
827 false
828 }
829 });
830
831 let schema_parts =
833 generate_schema_only_parts(&struct_name_str, &struct_docs, fields, &crate_path);
834
835 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
837
838 quote! {
839 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
840 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
841 match mode {
842 "schema_only" => #schema_parts,
843 "example_only" => #example_parts,
844 "full" | _ => {
845 let mut parts = Vec::new();
847
848 let schema_parts = #schema_parts;
850 parts.extend(schema_parts);
851
852 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
854 parts.push(#crate_path::prompt::PromptPart::Text(
855 format!("Here is an example of a valid `{}` object:", #struct_name_str)
856 ));
857
858 let example_parts = #example_parts;
860 parts.extend(example_parts);
861
862 parts
863 }
864 }
865 }
866
867 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
868 self.to_prompt_parts_with_mode("full")
869 }
870
871 fn to_prompt(&self) -> String {
872 self.to_prompt_parts()
873 .into_iter()
874 .filter_map(|part| match part {
875 #crate_path::prompt::PromptPart::Text(text) => Some(text),
876 _ => None,
877 })
878 .collect::<Vec<_>>()
879 .join("\n")
880 }
881
882 fn prompt_schema() -> String {
883 use std::sync::OnceLock;
884 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
885
886 SCHEMA_CACHE.get_or_init(|| {
887 let schema_parts = #schema_parts;
888 schema_parts
889 .into_iter()
890 .filter_map(|part| match part {
891 #crate_path::prompt::PromptPart::Text(text) => Some(text),
892 _ => None,
893 })
894 .collect::<Vec<_>>()
895 .join("\n")
896 }).clone()
897 }
898 }
899 }
900 } else if let Some(template) = template_str {
901 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
904 &fields.named
905 } else {
906 panic!(
907 "Template prompt generation is only supported for structs with named fields."
908 );
909 };
910
911 let placeholders = parse_template_placeholders_with_mode(&template);
913 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
915 mode.is_some()
916 && fields
917 .iter()
918 .any(|f| f.ident.as_ref().unwrap() == field_name)
919 });
920
921 let mut image_field_parts = Vec::new();
922 for f in fields.iter() {
923 let field_name = f.ident.as_ref().unwrap();
924 let attrs = parse_field_prompt_attrs(&f.attrs);
925
926 if attrs.image {
927 image_field_parts.push(quote! {
929 parts.extend(self.#field_name.to_prompt_parts());
930 });
931 }
932 }
933
934 if has_mode_syntax {
936 let mut context_fields = Vec::new();
938 let mut modified_template = template.clone();
939
940 for (field_name, mode_opt) in &placeholders {
942 if let Some(mode) = mode_opt {
943 let unique_key = format!("{}__{}", field_name, mode);
945
946 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
948 let replacement = format!("{{{{ {} }}}}", unique_key);
949 modified_template = modified_template.replace(&pattern, &replacement);
950
951 let field_ident =
953 syn::Ident::new(field_name, proc_macro2::Span::call_site());
954
955 context_fields.push(quote! {
957 context.insert(
958 #unique_key.to_string(),
959 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
960 );
961 });
962 }
963 }
964
965 for field in fields.iter() {
967 let field_name = field.ident.as_ref().unwrap();
968 let field_name_str = field_name.to_string();
969
970 let has_mode_entry = placeholders
972 .iter()
973 .any(|(name, mode)| name == &field_name_str && mode.is_some());
974
975 if !has_mode_entry {
976 let is_primitive = match &field.ty {
979 syn::Type::Path(type_path) => {
980 if let Some(segment) = type_path.path.segments.last() {
981 let type_name = segment.ident.to_string();
982 matches!(
983 type_name.as_str(),
984 "String"
985 | "str"
986 | "i8"
987 | "i16"
988 | "i32"
989 | "i64"
990 | "i128"
991 | "isize"
992 | "u8"
993 | "u16"
994 | "u32"
995 | "u64"
996 | "u128"
997 | "usize"
998 | "f32"
999 | "f64"
1000 | "bool"
1001 | "char"
1002 )
1003 } else {
1004 false
1005 }
1006 }
1007 _ => false,
1008 };
1009
1010 if is_primitive {
1011 context_fields.push(quote! {
1012 context.insert(
1013 #field_name_str.to_string(),
1014 minijinja::Value::from_serialize(&self.#field_name)
1015 );
1016 });
1017 } else {
1018 context_fields.push(quote! {
1020 context.insert(
1021 #field_name_str.to_string(),
1022 minijinja::Value::from(self.#field_name.to_prompt())
1023 );
1024 });
1025 }
1026 }
1027 }
1028
1029 quote! {
1030 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1031 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1032 let mut parts = Vec::new();
1033
1034 #(#image_field_parts)*
1036
1037 let text = {
1039 let mut env = minijinja::Environment::new();
1040 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1041 panic!("Failed to parse template: {}", e)
1042 });
1043
1044 let tmpl = env.get_template("prompt").unwrap();
1045
1046 let mut context = std::collections::HashMap::new();
1047 #(#context_fields)*
1048
1049 tmpl.render(context).unwrap_or_else(|e| {
1050 format!("Failed to render prompt: {}", e)
1051 })
1052 };
1053
1054 if !text.is_empty() {
1055 parts.push(#crate_path::prompt::PromptPart::Text(text));
1056 }
1057
1058 parts
1059 }
1060
1061 fn to_prompt(&self) -> String {
1062 let mut env = minijinja::Environment::new();
1064 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1065 panic!("Failed to parse template: {}", e)
1066 });
1067
1068 let tmpl = env.get_template("prompt").unwrap();
1069
1070 let mut context = std::collections::HashMap::new();
1071 #(#context_fields)*
1072
1073 tmpl.render(context).unwrap_or_else(|e| {
1074 format!("Failed to render prompt: {}", e)
1075 })
1076 }
1077
1078 fn prompt_schema() -> String {
1079 String::new() }
1081 }
1082 }
1083 } else {
1084 quote! {
1086 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1087 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1088 let mut parts = Vec::new();
1089
1090 #(#image_field_parts)*
1092
1093 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1095 format!("Failed to render prompt: {}", e)
1096 });
1097 if !text.is_empty() {
1098 parts.push(#crate_path::prompt::PromptPart::Text(text));
1099 }
1100
1101 parts
1102 }
1103
1104 fn to_prompt(&self) -> String {
1105 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1106 format!("Failed to render prompt: {}", e)
1107 })
1108 }
1109
1110 fn prompt_schema() -> String {
1111 String::new() }
1113 }
1114 }
1115 }
1116 } else {
1117 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1120 &fields.named
1121 } else {
1122 panic!(
1123 "Default prompt generation is only supported for structs with named fields."
1124 );
1125 };
1126
1127 let mut text_field_parts = Vec::new();
1129 let mut image_field_parts = Vec::new();
1130
1131 for f in fields.iter() {
1132 let field_name = f.ident.as_ref().unwrap();
1133 let attrs = parse_field_prompt_attrs(&f.attrs);
1134
1135 if attrs.skip {
1137 continue;
1138 }
1139
1140 if attrs.image {
1141 image_field_parts.push(quote! {
1143 parts.extend(self.#field_name.to_prompt_parts());
1144 });
1145 } else {
1146 let key = if let Some(rename) = attrs.rename {
1152 rename
1153 } else {
1154 let doc_comment = extract_doc_comments(&f.attrs);
1155 if !doc_comment.is_empty() {
1156 doc_comment
1157 } else {
1158 field_name.to_string()
1159 }
1160 };
1161
1162 let value_expr = if let Some(format_with) = attrs.format_with {
1164 let func_path: syn::Path =
1166 syn::parse_str(&format_with).unwrap_or_else(|_| {
1167 panic!("Invalid function path: {}", format_with)
1168 });
1169 quote! { #func_path(&self.#field_name) }
1170 } else {
1171 quote! { self.#field_name.to_prompt() }
1172 };
1173
1174 text_field_parts.push(quote! {
1175 text_parts.push(format!("{}: {}", #key, #value_expr));
1176 });
1177 }
1178 }
1179
1180 quote! {
1182 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1183 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1184 let mut parts = Vec::new();
1185
1186 #(#image_field_parts)*
1188
1189 let mut text_parts = Vec::new();
1191 #(#text_field_parts)*
1192
1193 if !text_parts.is_empty() {
1194 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1195 }
1196
1197 parts
1198 }
1199
1200 fn to_prompt(&self) -> String {
1201 let mut text_parts = Vec::new();
1202 #(#text_field_parts)*
1203 text_parts.join("\n")
1204 }
1205
1206 fn prompt_schema() -> String {
1207 String::new() }
1209 }
1210 }
1211 };
1212
1213 TokenStream::from(expanded)
1214 }
1215 Data::Union(_) => {
1216 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1217 }
1218 }
1219}
1220
/// A prompt target declared on a struct with `#[prompt_for(name = "...", ...)]`.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name taken from `name = "..."`.
    name: String,
    /// Template string from `template = "..."`, if one was given.
    template: Option<String>,
    /// Per-field options for this target, keyed by the field's identifier string.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}

/// Options for one field within one target, parsed from `#[prompt_for(...)]` on the field.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: exclude the field. A bare `#[prompt_for(skip)]` is recorded under the `*` target.
    skip: bool,
    /// `rename = "..."`: alternative label for the field.
    rename: Option<String>,
    /// `format_with = "..."`: path of a formatting function, stored as a string.
    format_with: Option<String>,
    /// `image`: flag set by a bare `image` meta.
    image: bool,
    /// Set whenever the config names an explicit target (`name = "..."`).
    /// NOTE(review): how this flag is consumed is outside the visible code — confirm.
    include_only: bool, }
1238
1239fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1241 let mut configs = Vec::new();
1242
1243 for attr in attrs {
1244 if attr.path().is_ident("prompt_for")
1245 && let Ok(meta_list) = attr.meta.require_list()
1246 {
1247 if meta_list.tokens.to_string() == "skip" {
1249 let config = FieldTargetConfig {
1251 skip: true,
1252 ..Default::default()
1253 };
1254 configs.push(("*".to_string(), config));
1255 } else if let Ok(metas) =
1256 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1257 {
1258 let mut target_name = None;
1259 let mut config = FieldTargetConfig::default();
1260
1261 for meta in metas {
1262 match meta {
1263 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1264 if let syn::Expr::Lit(syn::ExprLit {
1265 lit: syn::Lit::Str(lit_str),
1266 ..
1267 }) = nv.value
1268 {
1269 target_name = Some(lit_str.value());
1270 }
1271 }
1272 Meta::Path(path) if path.is_ident("skip") => {
1273 config.skip = true;
1274 }
1275 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1276 if let syn::Expr::Lit(syn::ExprLit {
1277 lit: syn::Lit::Str(lit_str),
1278 ..
1279 }) = nv.value
1280 {
1281 config.rename = Some(lit_str.value());
1282 }
1283 }
1284 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1285 if let syn::Expr::Lit(syn::ExprLit {
1286 lit: syn::Lit::Str(lit_str),
1287 ..
1288 }) = nv.value
1289 {
1290 config.format_with = Some(lit_str.value());
1291 }
1292 }
1293 Meta::Path(path) if path.is_ident("image") => {
1294 config.image = true;
1295 }
1296 _ => {}
1297 }
1298 }
1299
1300 if let Some(name) = target_name {
1301 config.include_only = true;
1302 configs.push((name, config));
1303 }
1304 }
1305 }
1306 }
1307
1308 configs
1309}
1310
1311fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1313 let mut targets = Vec::new();
1314
1315 for attr in attrs {
1316 if attr.path().is_ident("prompt_for")
1317 && let Ok(meta_list) = attr.meta.require_list()
1318 && let Ok(metas) =
1319 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1320 {
1321 let mut target_name = None;
1322 let mut template = None;
1323
1324 for meta in metas {
1325 match meta {
1326 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1327 if let syn::Expr::Lit(syn::ExprLit {
1328 lit: syn::Lit::Str(lit_str),
1329 ..
1330 }) = nv.value
1331 {
1332 target_name = Some(lit_str.value());
1333 }
1334 }
1335 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1336 if let syn::Expr::Lit(syn::ExprLit {
1337 lit: syn::Lit::Str(lit_str),
1338 ..
1339 }) = nv.value
1340 {
1341 template = Some(lit_str.value());
1342 }
1343 }
1344 _ => {}
1345 }
1346 }
1347
1348 if let Some(name) = target_name {
1349 targets.push(TargetInfo {
1350 name,
1351 template,
1352 field_configs: std::collections::HashMap::new(),
1353 });
1354 }
1355 }
1356 }
1357
1358 targets
1359}
1360
1361#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1362pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1363 let input = parse_macro_input!(input as DeriveInput);
1364
1365 let found_crate =
1366 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1367 let crate_path = match found_crate {
1368 FoundCrate::Itself => {
1369 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1371 quote!(::#ident)
1372 }
1373 FoundCrate::Name(name) => {
1374 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1375 quote!(::#ident)
1376 }
1377 };
1378
1379 let data_struct = match &input.data {
1381 Data::Struct(data) => data,
1382 _ => {
1383 return syn::Error::new(
1384 input.ident.span(),
1385 "`#[derive(ToPromptSet)]` is only supported for structs",
1386 )
1387 .to_compile_error()
1388 .into();
1389 }
1390 };
1391
1392 let fields = match &data_struct.fields {
1393 syn::Fields::Named(fields) => &fields.named,
1394 _ => {
1395 return syn::Error::new(
1396 input.ident.span(),
1397 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1398 )
1399 .to_compile_error()
1400 .into();
1401 }
1402 };
1403
1404 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1406
1407 for field in fields.iter() {
1409 let field_name = field.ident.as_ref().unwrap().to_string();
1410 let field_configs = parse_prompt_for_attrs(&field.attrs);
1411
1412 for (target_name, config) in field_configs {
1413 if target_name == "*" {
1414 for target in &mut targets {
1416 target
1417 .field_configs
1418 .entry(field_name.clone())
1419 .or_insert_with(FieldTargetConfig::default)
1420 .skip = config.skip;
1421 }
1422 } else {
1423 let target_exists = targets.iter().any(|t| t.name == target_name);
1425 if !target_exists {
1426 targets.push(TargetInfo {
1428 name: target_name.clone(),
1429 template: None,
1430 field_configs: std::collections::HashMap::new(),
1431 });
1432 }
1433
1434 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1435
1436 target.field_configs.insert(field_name.clone(), config);
1437 }
1438 }
1439 }
1440
1441 let mut match_arms = Vec::new();
1443
1444 for target in &targets {
1445 let target_name = &target.name;
1446
1447 if let Some(template_str) = &target.template {
1448 let mut image_parts = Vec::new();
1450
1451 for field in fields.iter() {
1452 let field_name = field.ident.as_ref().unwrap();
1453 let field_name_str = field_name.to_string();
1454
1455 if let Some(config) = target.field_configs.get(&field_name_str)
1456 && config.image
1457 {
1458 image_parts.push(quote! {
1459 parts.extend(self.#field_name.to_prompt_parts());
1460 });
1461 }
1462 }
1463
1464 match_arms.push(quote! {
1465 #target_name => {
1466 let mut parts = Vec::new();
1467
1468 #(#image_parts)*
1469
1470 let text = #crate_path::prompt::render_prompt(#template_str, self)
1471 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1472 target: #target_name.to_string(),
1473 source: e,
1474 })?;
1475
1476 if !text.is_empty() {
1477 parts.push(#crate_path::prompt::PromptPart::Text(text));
1478 }
1479
1480 Ok(parts)
1481 }
1482 });
1483 } else {
1484 let mut text_field_parts = Vec::new();
1486 let mut image_field_parts = Vec::new();
1487
1488 for field in fields.iter() {
1489 let field_name = field.ident.as_ref().unwrap();
1490 let field_name_str = field_name.to_string();
1491
1492 let config = target.field_configs.get(&field_name_str);
1494
1495 if let Some(cfg) = config
1497 && cfg.skip
1498 {
1499 continue;
1500 }
1501
1502 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1506 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1507 .iter()
1508 .any(|(name, _)| name != "*");
1509
1510 if has_any_target_specific_config && !is_explicitly_for_this_target {
1511 continue;
1512 }
1513
1514 if let Some(cfg) = config {
1515 if cfg.image {
1516 image_field_parts.push(quote! {
1517 parts.extend(self.#field_name.to_prompt_parts());
1518 });
1519 } else {
1520 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1521
1522 let value_expr = if let Some(format_with) = &cfg.format_with {
1523 match syn::parse_str::<syn::Path>(format_with) {
1525 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1526 Err(_) => {
1527 let error_msg = format!(
1529 "Invalid function path in format_with: '{}'",
1530 format_with
1531 );
1532 quote! {
1533 compile_error!(#error_msg);
1534 String::new()
1535 }
1536 }
1537 }
1538 } else {
1539 quote! { self.#field_name.to_prompt() }
1540 };
1541
1542 text_field_parts.push(quote! {
1543 text_parts.push(format!("{}: {}", #key, #value_expr));
1544 });
1545 }
1546 } else {
1547 text_field_parts.push(quote! {
1549 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1550 });
1551 }
1552 }
1553
1554 match_arms.push(quote! {
1555 #target_name => {
1556 let mut parts = Vec::new();
1557
1558 #(#image_field_parts)*
1559
1560 let mut text_parts = Vec::new();
1561 #(#text_field_parts)*
1562
1563 if !text_parts.is_empty() {
1564 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1565 }
1566
1567 Ok(parts)
1568 }
1569 });
1570 }
1571 }
1572
1573 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1575
1576 match_arms.push(quote! {
1578 _ => {
1579 let available = vec![#(#target_names.to_string()),*];
1580 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1581 target: target.to_string(),
1582 available,
1583 })
1584 }
1585 });
1586
1587 let struct_name = &input.ident;
1588 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1589
1590 let expanded = quote! {
1591 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1592 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1593 match target {
1594 #(#match_arms)*
1595 }
1596 }
1597 }
1598 };
1599
1600 TokenStream::from(expanded)
1601}
1602
1603struct TypeList {
1605 types: Punctuated<syn::Type, Token![,]>,
1606}
1607
1608impl Parse for TypeList {
1609 fn parse(input: ParseStream) -> syn::Result<Self> {
1610 Ok(TypeList {
1611 types: Punctuated::parse_terminated(input)?,
1612 })
1613 }
1614}
1615
1616#[proc_macro]
1640pub fn examples_section(input: TokenStream) -> TokenStream {
1641 let input = parse_macro_input!(input as TypeList);
1642
1643 let found_crate =
1644 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1645 let _crate_path = match found_crate {
1646 FoundCrate::Itself => quote!(crate),
1647 FoundCrate::Name(name) => {
1648 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1649 quote!(::#ident)
1650 }
1651 };
1652
1653 let mut type_sections = Vec::new();
1655
1656 for ty in input.types.iter() {
1657 let type_name_str = quote!(#ty).to_string();
1659
1660 type_sections.push(quote! {
1662 {
1663 let type_name = #type_name_str;
1664 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1665 format!("---\n#### `{}`\n{}", type_name, json_example)
1666 }
1667 });
1668 }
1669
1670 let expanded = quote! {
1672 {
1673 let mut sections = Vec::new();
1674 sections.push("---".to_string());
1675 sections.push("### Examples".to_string());
1676 sections.push("".to_string());
1677 sections.push("Here are examples of the data structures you should use.".to_string());
1678 sections.push("".to_string());
1679
1680 #(sections.push(#type_sections);)*
1681
1682 sections.push("---".to_string());
1683
1684 sections.join("\n")
1685 }
1686 };
1687
1688 TokenStream::from(expanded)
1689}
1690
1691fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1693 for attr in attrs {
1694 if attr.path().is_ident("prompt_for")
1695 && let Ok(meta_list) = attr.meta.require_list()
1696 && let Ok(metas) =
1697 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1698 {
1699 let mut target_type = None;
1700 let mut template = None;
1701
1702 for meta in metas {
1703 match meta {
1704 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1705 if let syn::Expr::Lit(syn::ExprLit {
1706 lit: syn::Lit::Str(lit_str),
1707 ..
1708 }) = nv.value
1709 {
1710 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1712 }
1713 }
1714 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1715 if let syn::Expr::Lit(syn::ExprLit {
1716 lit: syn::Lit::Str(lit_str),
1717 ..
1718 }) = nv.value
1719 {
1720 template = Some(lit_str.value());
1721 }
1722 }
1723 _ => {}
1724 }
1725 }
1726
1727 if let (Some(target), Some(tmpl)) = (target_type, template) {
1728 return (target, tmpl);
1729 }
1730 }
1731 }
1732
1733 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1734}
1735
1736#[proc_macro_attribute]
1770pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1771 let input = parse_macro_input!(item as DeriveInput);
1772
1773 let found_crate =
1774 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1775 let crate_path = match found_crate {
1776 FoundCrate::Itself => {
1777 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1779 quote!(::#ident)
1780 }
1781 FoundCrate::Name(name) => {
1782 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1783 quote!(::#ident)
1784 }
1785 };
1786
1787 let enum_data = match &input.data {
1789 Data::Enum(data) => data,
1790 _ => {
1791 return syn::Error::new(
1792 input.ident.span(),
1793 "`#[define_intent]` can only be applied to enums",
1794 )
1795 .to_compile_error()
1796 .into();
1797 }
1798 };
1799
1800 let mut prompt_template = None;
1802 let mut extractor_tag = None;
1803 let mut mode = None;
1804
1805 for attr in &input.attrs {
1806 if attr.path().is_ident("intent")
1807 && let Ok(metas) =
1808 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1809 {
1810 for meta in metas {
1811 match meta {
1812 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1813 if let syn::Expr::Lit(syn::ExprLit {
1814 lit: syn::Lit::Str(lit_str),
1815 ..
1816 }) = nv.value
1817 {
1818 prompt_template = Some(lit_str.value());
1819 }
1820 }
1821 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1822 if let syn::Expr::Lit(syn::ExprLit {
1823 lit: syn::Lit::Str(lit_str),
1824 ..
1825 }) = nv.value
1826 {
1827 extractor_tag = Some(lit_str.value());
1828 }
1829 }
1830 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1831 if let syn::Expr::Lit(syn::ExprLit {
1832 lit: syn::Lit::Str(lit_str),
1833 ..
1834 }) = nv.value
1835 {
1836 mode = Some(lit_str.value());
1837 }
1838 }
1839 _ => {}
1840 }
1841 }
1842 }
1843 }
1844
1845 let mode = mode.unwrap_or_else(|| "single".to_string());
1847
1848 if mode != "single" && mode != "multi_tag" {
1850 return syn::Error::new(
1851 input.ident.span(),
1852 "`mode` must be either \"single\" or \"multi_tag\"",
1853 )
1854 .to_compile_error()
1855 .into();
1856 }
1857
1858 let prompt_template = match prompt_template {
1860 Some(p) => p,
1861 None => {
1862 return syn::Error::new(
1863 input.ident.span(),
1864 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1865 )
1866 .to_compile_error()
1867 .into();
1868 }
1869 };
1870
1871 if mode == "multi_tag" {
1873 let enum_name = &input.ident;
1874 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1875 return generate_multi_tag_output(
1876 &input,
1877 enum_name,
1878 enum_data,
1879 prompt_template,
1880 actions_doc,
1881 );
1882 }
1883
1884 let extractor_tag = match extractor_tag {
1886 Some(t) => t,
1887 None => {
1888 return syn::Error::new(
1889 input.ident.span(),
1890 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1891 )
1892 .to_compile_error()
1893 .into();
1894 }
1895 };
1896
1897 let enum_name = &input.ident;
1899 let enum_docs = extract_doc_comments(&input.attrs);
1900
1901 let mut intents_doc_lines = Vec::new();
1902
1903 if !enum_docs.is_empty() {
1905 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1906 } else {
1907 intents_doc_lines.push(format!("{}:", enum_name));
1908 }
1909 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
1911
1912 for variant in &enum_data.variants {
1914 let variant_name = &variant.ident;
1915 let variant_docs = extract_doc_comments(&variant.attrs);
1916
1917 if !variant_docs.is_empty() {
1918 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1919 } else {
1920 intents_doc_lines.push(format!("- {}", variant_name));
1921 }
1922 }
1923
1924 let intents_doc_str = intents_doc_lines.join("\n");
1925
1926 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
1928 let user_variables: Vec<String> = placeholders
1929 .iter()
1930 .filter_map(|(name, _)| {
1931 if name != "intents_doc" {
1932 Some(name.clone())
1933 } else {
1934 None
1935 }
1936 })
1937 .collect();
1938
1939 let enum_name_str = enum_name.to_string();
1941 let snake_case_name = to_snake_case(&enum_name_str);
1942 let function_name = syn::Ident::new(
1943 &format!("build_{}_prompt", snake_case_name),
1944 proc_macro2::Span::call_site(),
1945 );
1946
1947 let function_params: Vec<proc_macro2::TokenStream> = user_variables
1949 .iter()
1950 .map(|var| {
1951 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1952 quote! { #ident: &str }
1953 })
1954 .collect();
1955
1956 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1958 .iter()
1959 .map(|var| {
1960 let var_str = var.clone();
1961 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1962 quote! {
1963 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1964 }
1965 })
1966 .collect();
1967
1968 let converted_template = prompt_template.clone();
1970
1971 let extractor_name = syn::Ident::new(
1973 &format!("{}Extractor", enum_name),
1974 proc_macro2::Span::call_site(),
1975 );
1976
1977 let filtered_attrs: Vec<_> = input
1979 .attrs
1980 .iter()
1981 .filter(|attr| !attr.path().is_ident("intent"))
1982 .collect();
1983
1984 let vis = &input.vis;
1986 let generics = &input.generics;
1987 let variants = &enum_data.variants;
1988 let enum_output = quote! {
1989 #(#filtered_attrs)*
1990 #vis enum #enum_name #generics {
1991 #variants
1992 }
1993 };
1994
1995 let expanded = quote! {
1997 #enum_output
1999
2000 pub fn #function_name(#(#function_params),*) -> String {
2002 let mut env = minijinja::Environment::new();
2003 env.add_template("prompt", #converted_template)
2004 .expect("Failed to parse intent prompt template");
2005
2006 let tmpl = env.get_template("prompt").unwrap();
2007
2008 let mut __template_context = std::collections::HashMap::new();
2009
2010 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2012
2013 #(#context_insertions)*
2015
2016 tmpl.render(&__template_context)
2017 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2018 }
2019
2020 pub struct #extractor_name;
2022
2023 impl #extractor_name {
2024 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2025 }
2026
2027 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2028 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2029 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2031 }
2032 }
2033 };
2034
2035 TokenStream::from(expanded)
2036}
2037
2038fn to_snake_case(s: &str) -> String {
2040 let mut result = String::new();
2041 let mut prev_upper = false;
2042
2043 for (i, ch) in s.chars().enumerate() {
2044 if ch.is_uppercase() {
2045 if i > 0 && !prev_upper {
2046 result.push('_');
2047 }
2048 result.push(ch.to_lowercase().next().unwrap());
2049 prev_upper = true;
2050 } else {
2051 result.push(ch);
2052 prev_upper = false;
2053 }
2054 }
2055
2056 result
2057}
2058
2059#[derive(Debug, Default)]
2061struct ActionAttrs {
2062 tag: Option<String>,
2063}
2064
2065fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2066 let mut result = ActionAttrs::default();
2067
2068 for attr in attrs {
2069 if attr.path().is_ident("action")
2070 && let Ok(meta_list) = attr.meta.require_list()
2071 && let Ok(metas) =
2072 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2073 {
2074 for meta in metas {
2075 if let Meta::NameValue(nv) = meta
2076 && nv.path.is_ident("tag")
2077 && let syn::Expr::Lit(syn::ExprLit {
2078 lit: syn::Lit::Str(lit_str),
2079 ..
2080 }) = nv.value
2081 {
2082 result.tag = Some(lit_str.value());
2083 }
2084 }
2085 }
2086 }
2087
2088 result
2089}
2090
2091#[derive(Debug, Default)]
2093struct FieldActionAttrs {
2094 is_attribute: bool,
2095 is_inner_text: bool,
2096}
2097
2098fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2099 let mut result = FieldActionAttrs::default();
2100
2101 for attr in attrs {
2102 if attr.path().is_ident("action")
2103 && let Ok(meta_list) = attr.meta.require_list()
2104 {
2105 let tokens_str = meta_list.tokens.to_string();
2106 if tokens_str == "attribute" {
2107 result.is_attribute = true;
2108 } else if tokens_str == "inner_text" {
2109 result.is_inner_text = true;
2110 }
2111 }
2112 }
2113
2114 result
2115}
2116
2117fn generate_multi_tag_actions_doc(
2119 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2120) -> String {
2121 let mut doc_lines = Vec::new();
2122
2123 for variant in variants {
2124 let action_attrs = parse_action_attrs(&variant.attrs);
2125
2126 if let Some(tag) = action_attrs.tag {
2127 let variant_docs = extract_doc_comments(&variant.attrs);
2128
2129 match &variant.fields {
2130 syn::Fields::Unit => {
2131 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2133 }
2134 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2135 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2137 }
2138 syn::Fields::Named(fields) => {
2139 let mut attrs_str = Vec::new();
2141 let mut has_inner_text = false;
2142
2143 for field in &fields.named {
2144 let field_name = field.ident.as_ref().unwrap();
2145 let field_attrs = parse_field_action_attrs(&field.attrs);
2146
2147 if field_attrs.is_attribute {
2148 attrs_str.push(format!("{}=\"...\"", field_name));
2149 } else if field_attrs.is_inner_text {
2150 has_inner_text = true;
2151 }
2152 }
2153
2154 let attrs_part = if !attrs_str.is_empty() {
2155 format!(" {}", attrs_str.join(" "))
2156 } else {
2157 String::new()
2158 };
2159
2160 if has_inner_text {
2161 doc_lines.push(format!(
2162 "- `<{}{}>...</{}>`: {}",
2163 tag, attrs_part, tag, variant_docs
2164 ));
2165 } else if !attrs_str.is_empty() {
2166 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2167 } else {
2168 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2169 }
2170
2171 for field in &fields.named {
2173 let field_name = field.ident.as_ref().unwrap();
2174 let field_attrs = parse_field_action_attrs(&field.attrs);
2175 let field_docs = extract_doc_comments(&field.attrs);
2176
2177 if field_attrs.is_attribute {
2178 doc_lines
2179 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2180 } else if field_attrs.is_inner_text {
2181 doc_lines
2182 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2183 }
2184 }
2185 }
2186 _ => {
2187 }
2189 }
2190 }
2191 }
2192
2193 doc_lines.join("\n")
2194}
2195
2196fn generate_tags_regex(
2198 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2199) -> String {
2200 let mut tag_names = Vec::new();
2201
2202 for variant in variants {
2203 let action_attrs = parse_action_attrs(&variant.attrs);
2204 if let Some(tag) = action_attrs.tag {
2205 tag_names.push(tag);
2206 }
2207 }
2208
2209 if tag_names.is_empty() {
2210 return String::new();
2211 }
2212
2213 let tags_pattern = tag_names.join("|");
2214 format!(
2217 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2218 tags_pattern, tags_pattern, tags_pattern
2219 )
2220}
2221
2222fn generate_multi_tag_output(
2224 input: &DeriveInput,
2225 enum_name: &syn::Ident,
2226 enum_data: &syn::DataEnum,
2227 prompt_template: String,
2228 actions_doc: String,
2229) -> TokenStream {
2230 let found_crate =
2231 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2232 let crate_path = match found_crate {
2233 FoundCrate::Itself => {
2234 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2236 quote!(::#ident)
2237 }
2238 FoundCrate::Name(name) => {
2239 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2240 quote!(::#ident)
2241 }
2242 };
2243
2244 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2246 let user_variables: Vec<String> = placeholders
2247 .iter()
2248 .filter_map(|(name, _)| {
2249 if name != "actions_doc" {
2250 Some(name.clone())
2251 } else {
2252 None
2253 }
2254 })
2255 .collect();
2256
2257 let enum_name_str = enum_name.to_string();
2259 let snake_case_name = to_snake_case(&enum_name_str);
2260 let function_name = syn::Ident::new(
2261 &format!("build_{}_prompt", snake_case_name),
2262 proc_macro2::Span::call_site(),
2263 );
2264
2265 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2267 .iter()
2268 .map(|var| {
2269 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2270 quote! { #ident: &str }
2271 })
2272 .collect();
2273
2274 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2276 .iter()
2277 .map(|var| {
2278 let var_str = var.clone();
2279 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2280 quote! {
2281 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2282 }
2283 })
2284 .collect();
2285
2286 let extractor_name = syn::Ident::new(
2288 &format!("{}Extractor", enum_name),
2289 proc_macro2::Span::call_site(),
2290 );
2291
2292 let filtered_attrs: Vec<_> = input
2294 .attrs
2295 .iter()
2296 .filter(|attr| !attr.path().is_ident("intent"))
2297 .collect();
2298
2299 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2301 .variants
2302 .iter()
2303 .map(|variant| {
2304 let variant_name = &variant.ident;
2305 let variant_attrs: Vec<_> = variant
2306 .attrs
2307 .iter()
2308 .filter(|attr| !attr.path().is_ident("action"))
2309 .collect();
2310 let fields = &variant.fields;
2311
2312 let filtered_fields = match fields {
2314 syn::Fields::Named(named_fields) => {
2315 let filtered: Vec<_> = named_fields
2316 .named
2317 .iter()
2318 .map(|field| {
2319 let field_name = &field.ident;
2320 let field_type = &field.ty;
2321 let field_vis = &field.vis;
2322 let filtered_attrs: Vec<_> = field
2323 .attrs
2324 .iter()
2325 .filter(|attr| !attr.path().is_ident("action"))
2326 .collect();
2327 quote! {
2328 #(#filtered_attrs)*
2329 #field_vis #field_name: #field_type
2330 }
2331 })
2332 .collect();
2333 quote! { { #(#filtered,)* } }
2334 }
2335 syn::Fields::Unnamed(unnamed_fields) => {
2336 let types: Vec<_> = unnamed_fields
2337 .unnamed
2338 .iter()
2339 .map(|field| {
2340 let field_type = &field.ty;
2341 quote! { #field_type }
2342 })
2343 .collect();
2344 quote! { (#(#types),*) }
2345 }
2346 syn::Fields::Unit => quote! {},
2347 };
2348
2349 quote! {
2350 #(#variant_attrs)*
2351 #variant_name #filtered_fields
2352 }
2353 })
2354 .collect();
2355
2356 let vis = &input.vis;
2357 let generics = &input.generics;
2358
2359 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2361
2362 let tags_regex = generate_tags_regex(&enum_data.variants);
2364
2365 let expanded = quote! {
2366 #(#filtered_attrs)*
2368 #vis enum #enum_name #generics {
2369 #(#filtered_variants),*
2370 }
2371
2372 pub fn #function_name(#(#function_params),*) -> String {
2374 let mut env = minijinja::Environment::new();
2375 env.add_template("prompt", #prompt_template)
2376 .expect("Failed to parse intent prompt template");
2377
2378 let tmpl = env.get_template("prompt").unwrap();
2379
2380 let mut __template_context = std::collections::HashMap::new();
2381
2382 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2384
2385 #(#context_insertions)*
2387
2388 tmpl.render(&__template_context)
2389 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2390 }
2391
2392 pub struct #extractor_name;
2394
2395 impl #extractor_name {
2396 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2397 use ::quick_xml::events::Event;
2398 use ::quick_xml::Reader;
2399
2400 let mut actions = Vec::new();
2401 let mut reader = Reader::from_str(text);
2402 reader.config_mut().trim_text(true);
2403
2404 let mut buf = Vec::new();
2405
2406 loop {
2407 match reader.read_event_into(&mut buf) {
2408 Ok(Event::Start(e)) => {
2409 let owned_e = e.into_owned();
2410 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2411 let is_empty = false;
2412
2413 #parsing_arms
2414 }
2415 Ok(Event::Empty(e)) => {
2416 let owned_e = e.into_owned();
2417 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2418 let is_empty = true;
2419
2420 #parsing_arms
2421 }
2422 Ok(Event::Eof) => break,
2423 Err(_) => {
2424 break;
2426 }
2427 _ => {}
2428 }
2429 buf.clear();
2430 }
2431
2432 actions.into_iter().next()
2433 }
2434
2435 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2436 use ::quick_xml::events::Event;
2437 use ::quick_xml::Reader;
2438
2439 let mut actions = Vec::new();
2440 let mut reader = Reader::from_str(text);
2441 reader.config_mut().trim_text(true);
2442
2443 let mut buf = Vec::new();
2444
2445 loop {
2446 match reader.read_event_into(&mut buf) {
2447 Ok(Event::Start(e)) => {
2448 let owned_e = e.into_owned();
2449 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2450 let is_empty = false;
2451
2452 #parsing_arms
2453 }
2454 Ok(Event::Empty(e)) => {
2455 let owned_e = e.into_owned();
2456 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2457 let is_empty = true;
2458
2459 #parsing_arms
2460 }
2461 Ok(Event::Eof) => break,
2462 Err(_) => {
2463 break;
2465 }
2466 _ => {}
2467 }
2468 buf.clear();
2469 }
2470
2471 Ok(actions)
2472 }
2473
2474 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2475 where
2476 F: FnMut(#enum_name) -> String,
2477 {
2478 use ::regex::Regex;
2479
2480 let regex_pattern = #tags_regex;
2481 if regex_pattern.is_empty() {
2482 return text.to_string();
2483 }
2484
2485 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2486 panic!("Failed to compile regex for action tags: {}", e);
2487 });
2488
2489 re.replace_all(text, |caps: &::regex::Captures| {
2490 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2491
2492 if let Some(action) = self.parse_single_action(matched) {
2494 transformer(action)
2495 } else {
2496 matched.to_string()
2498 }
2499 }).to_string()
2500 }
2501
2502 pub fn strip_actions(&self, text: &str) -> String {
2503 self.transform_actions(text, |_| String::new())
2504 }
2505 }
2506 };
2507
2508 TokenStream::from(expanded)
2509}
2510
2511fn generate_parsing_arms(
2513 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2514 enum_name: &syn::Ident,
2515) -> proc_macro2::TokenStream {
2516 let mut arms = Vec::new();
2517
2518 for variant in variants {
2519 let variant_name = &variant.ident;
2520 let action_attrs = parse_action_attrs(&variant.attrs);
2521
2522 if let Some(tag) = action_attrs.tag {
2523 match &variant.fields {
2524 syn::Fields::Unit => {
2525 arms.push(quote! {
2527 if &tag_name == #tag {
2528 actions.push(#enum_name::#variant_name);
2529 }
2530 });
2531 }
2532 syn::Fields::Unnamed(_fields) => {
2533 arms.push(quote! {
2535 if &tag_name == #tag && !is_empty {
2536 match reader.read_text(owned_e.name()) {
2538 Ok(text) => {
2539 actions.push(#enum_name::#variant_name(text.to_string()));
2540 }
2541 Err(_) => {
2542 actions.push(#enum_name::#variant_name(String::new()));
2544 }
2545 }
2546 }
2547 });
2548 }
2549 syn::Fields::Named(fields) => {
2550 let mut field_names = Vec::new();
2552 let mut has_inner_text_field = None;
2553
2554 for field in &fields.named {
2555 let field_name = field.ident.as_ref().unwrap();
2556 let field_attrs = parse_field_action_attrs(&field.attrs);
2557
2558 if field_attrs.is_attribute {
2559 field_names.push(field_name.clone());
2560 } else if field_attrs.is_inner_text {
2561 has_inner_text_field = Some(field_name.clone());
2562 }
2563 }
2564
2565 if let Some(inner_text_field) = has_inner_text_field {
2566 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2569 quote! {
2570 let mut #field_name = String::new();
2571 for attr in owned_e.attributes() {
2572 if let Ok(attr) = attr {
2573 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2574 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2575 break;
2576 }
2577 }
2578 }
2579 }
2580 }).collect();
2581
2582 arms.push(quote! {
2583 if &tag_name == #tag {
2584 #(#attr_extractions)*
2585
2586 if is_empty {
2588 let #inner_text_field = String::new();
2589 actions.push(#enum_name::#variant_name {
2590 #(#field_names,)*
2591 #inner_text_field,
2592 });
2593 } else {
2594 match reader.read_text(owned_e.name()) {
2596 Ok(text) => {
2597 let #inner_text_field = text.to_string();
2598 actions.push(#enum_name::#variant_name {
2599 #(#field_names,)*
2600 #inner_text_field,
2601 });
2602 }
2603 Err(_) => {
2604 let #inner_text_field = String::new();
2606 actions.push(#enum_name::#variant_name {
2607 #(#field_names,)*
2608 #inner_text_field,
2609 });
2610 }
2611 }
2612 }
2613 }
2614 });
2615 } else {
2616 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2618 quote! {
2619 let mut #field_name = String::new();
2620 for attr in owned_e.attributes() {
2621 if let Ok(attr) = attr {
2622 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2623 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2624 break;
2625 }
2626 }
2627 }
2628 }
2629 }).collect();
2630
2631 arms.push(quote! {
2632 if &tag_name == #tag {
2633 #(#attr_extractions)*
2634 actions.push(#enum_name::#variant_name {
2635 #(#field_names),*
2636 });
2637 }
2638 });
2639 }
2640 }
2641 }
2642 }
2643 }
2644
2645 quote! {
2646 #(#arms)*
2647 }
2648}
2649
2650#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2652pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2653 let input = parse_macro_input!(input as DeriveInput);
2654
2655 let found_crate =
2656 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2657 let crate_path = match found_crate {
2658 FoundCrate::Itself => {
2659 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2661 quote!(::#ident)
2662 }
2663 FoundCrate::Name(name) => {
2664 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2665 quote!(::#ident)
2666 }
2667 };
2668
2669 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2671
2672 let struct_name = &input.ident;
2673 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2674
2675 let placeholders = parse_template_placeholders_with_mode(&template);
2677
2678 let mut converted_template = template.clone();
2680 let mut context_fields = Vec::new();
2681
2682 let fields = match &input.data {
2684 Data::Struct(data_struct) => match &data_struct.fields {
2685 syn::Fields::Named(fields) => &fields.named,
2686 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2687 },
2688 _ => panic!("ToPromptFor is only supported for structs"),
2689 };
2690
2691 let has_mode_support = input.attrs.iter().any(|attr| {
2693 if attr.path().is_ident("prompt")
2694 && let Ok(metas) =
2695 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2696 {
2697 for meta in metas {
2698 if let Meta::NameValue(nv) = meta
2699 && nv.path.is_ident("mode")
2700 {
2701 return true;
2702 }
2703 }
2704 }
2705 false
2706 });
2707
2708 for (placeholder_name, mode_opt) in &placeholders {
2710 if placeholder_name == "self" {
2711 if let Some(specific_mode) = mode_opt {
2712 let unique_key = format!("self__{}", specific_mode);
2714
2715 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2717 let replacement = format!("{{{{ {} }}}}", unique_key);
2718 converted_template = converted_template.replace(&pattern, &replacement);
2719
2720 context_fields.push(quote! {
2722 context.insert(
2723 #unique_key.to_string(),
2724 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2725 );
2726 });
2727 } else {
2728 if has_mode_support {
2731 context_fields.push(quote! {
2733 context.insert(
2734 "self".to_string(),
2735 minijinja::Value::from(self.to_prompt_with_mode(mode))
2736 );
2737 });
2738 } else {
2739 context_fields.push(quote! {
2741 context.insert(
2742 "self".to_string(),
2743 minijinja::Value::from(self.to_prompt())
2744 );
2745 });
2746 }
2747 }
2748 } else {
2749 let field_exists = fields.iter().any(|f| {
2752 f.ident
2753 .as_ref()
2754 .is_some_and(|ident| ident == placeholder_name)
2755 });
2756
2757 if field_exists {
2758 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2759
2760 context_fields.push(quote! {
2764 context.insert(
2765 #placeholder_name.to_string(),
2766 minijinja::Value::from_serialize(&self.#field_ident)
2767 );
2768 });
2769 }
2770 }
2772 }
2773
2774 let expanded = quote! {
2775 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2776 where
2777 #target_type: serde::Serialize,
2778 {
2779 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2780 let mut env = minijinja::Environment::new();
2782 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2783 panic!("Failed to parse template: {}", e)
2784 });
2785
2786 let tmpl = env.get_template("prompt").unwrap();
2787
2788 let mut context = std::collections::HashMap::new();
2790 context.insert(
2792 "self".to_string(),
2793 minijinja::Value::from_serialize(self)
2794 );
2795 context.insert(
2797 "target".to_string(),
2798 minijinja::Value::from_serialize(target)
2799 );
2800 #(#context_fields)*
2801
2802 tmpl.render(context).unwrap_or_else(|e| {
2804 format!("Failed to render prompt: {}", e)
2805 })
2806 }
2807 }
2808 };
2809
2810 TokenStream::from(expanded)
2811}
2812
2813struct AgentAttrs {
2819 expertise: Option<String>,
2820 output: Option<syn::Type>,
2821 backend: Option<String>,
2822 model: Option<String>,
2823 inner: Option<String>,
2824 default_inner: Option<String>,
2825 max_retries: Option<u32>,
2826 profile: Option<String>,
2827}
2828
2829impl Parse for AgentAttrs {
2830 fn parse(input: ParseStream) -> syn::Result<Self> {
2831 let mut expertise = None;
2832 let mut output = None;
2833 let mut backend = None;
2834 let mut model = None;
2835 let mut inner = None;
2836 let mut default_inner = None;
2837 let mut max_retries = None;
2838 let mut profile = None;
2839
2840 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2841
2842 for meta in pairs {
2843 match meta {
2844 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2845 if let syn::Expr::Lit(syn::ExprLit {
2846 lit: syn::Lit::Str(lit_str),
2847 ..
2848 }) = &nv.value
2849 {
2850 expertise = Some(lit_str.value());
2851 }
2852 }
2853 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2854 if let syn::Expr::Lit(syn::ExprLit {
2855 lit: syn::Lit::Str(lit_str),
2856 ..
2857 }) = &nv.value
2858 {
2859 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2860 output = Some(ty);
2861 }
2862 }
2863 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2864 if let syn::Expr::Lit(syn::ExprLit {
2865 lit: syn::Lit::Str(lit_str),
2866 ..
2867 }) = &nv.value
2868 {
2869 backend = Some(lit_str.value());
2870 }
2871 }
2872 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2873 if let syn::Expr::Lit(syn::ExprLit {
2874 lit: syn::Lit::Str(lit_str),
2875 ..
2876 }) = &nv.value
2877 {
2878 model = Some(lit_str.value());
2879 }
2880 }
2881 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
2882 if let syn::Expr::Lit(syn::ExprLit {
2883 lit: syn::Lit::Str(lit_str),
2884 ..
2885 }) = &nv.value
2886 {
2887 inner = Some(lit_str.value());
2888 }
2889 }
2890 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
2891 if let syn::Expr::Lit(syn::ExprLit {
2892 lit: syn::Lit::Str(lit_str),
2893 ..
2894 }) = &nv.value
2895 {
2896 default_inner = Some(lit_str.value());
2897 }
2898 }
2899 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
2900 if let syn::Expr::Lit(syn::ExprLit {
2901 lit: syn::Lit::Int(lit_int),
2902 ..
2903 }) = &nv.value
2904 {
2905 max_retries = Some(lit_int.base10_parse()?);
2906 }
2907 }
2908 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
2909 if let syn::Expr::Lit(syn::ExprLit {
2910 lit: syn::Lit::Str(lit_str),
2911 ..
2912 }) = &nv.value
2913 {
2914 profile = Some(lit_str.value());
2915 }
2916 }
2917 _ => {}
2918 }
2919 }
2920
2921 Ok(AgentAttrs {
2922 expertise,
2923 output,
2924 backend,
2925 model,
2926 inner,
2927 default_inner,
2928 max_retries,
2929 profile,
2930 })
2931 }
2932}
2933
2934fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
2936 for attr in attrs {
2937 if attr.path().is_ident("agent") {
2938 return attr.parse_args::<AgentAttrs>();
2939 }
2940 }
2941
2942 Ok(AgentAttrs {
2943 expertise: None,
2944 output: None,
2945 backend: None,
2946 model: None,
2947 inner: None,
2948 default_inner: None,
2949 max_retries: None,
2950 profile: None,
2951 })
2952}
2953
2954fn generate_backend_constructors(
2956 struct_name: &syn::Ident,
2957 backend: &str,
2958 _model: Option<&str>,
2959 _profile: Option<&str>,
2960 crate_path: &proc_macro2::TokenStream,
2961) -> proc_macro2::TokenStream {
2962 match backend {
2963 "claude" => {
2964 quote! {
2965 impl #struct_name {
2966 pub fn with_claude() -> Self {
2968 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
2969 }
2970
2971 pub fn with_claude_model(model: &str) -> Self {
2973 Self::new(
2974 #crate_path::agent::impls::ClaudeCodeAgent::new()
2975 .with_model_str(model)
2976 )
2977 }
2978 }
2979 }
2980 }
2981 "gemini" => {
2982 quote! {
2983 impl #struct_name {
2984 pub fn with_gemini() -> Self {
2986 Self::new(#crate_path::agent::impls::GeminiAgent::new())
2987 }
2988
2989 pub fn with_gemini_model(model: &str) -> Self {
2991 Self::new(
2992 #crate_path::agent::impls::GeminiAgent::new()
2993 .with_model_str(model)
2994 )
2995 }
2996 }
2997 }
2998 }
2999 _ => quote! {},
3000 }
3001}
3002
3003fn generate_default_impl(
3005 struct_name: &syn::Ident,
3006 backend: &str,
3007 model: Option<&str>,
3008 profile: Option<&str>,
3009 crate_path: &proc_macro2::TokenStream,
3010) -> proc_macro2::TokenStream {
3011 let profile_expr = if let Some(profile_str) = profile {
3013 match profile_str.to_lowercase().as_str() {
3014 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3015 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3016 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3017 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3019 } else {
3020 quote! { #crate_path::agent::ExecutionProfile::default() }
3021 };
3022
3023 let agent_init = match backend {
3024 "gemini" => {
3025 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3026
3027 if let Some(model_str) = model {
3028 builder = quote! { #builder.with_model_str(#model_str) };
3029 }
3030
3031 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3032 builder
3033 }
3034 _ => {
3035 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3037
3038 if let Some(model_str) = model {
3039 builder = quote! { #builder.with_model_str(#model_str) };
3040 }
3041
3042 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3043 builder
3044 }
3045 };
3046
3047 quote! {
3048 impl Default for #struct_name {
3049 fn default() -> Self {
3050 Self::new(#agent_init)
3051 }
3052 }
3053 }
3054}
3055
3056#[proc_macro_derive(Agent, attributes(agent))]
3065pub fn derive_agent(input: TokenStream) -> TokenStream {
3066 let input = parse_macro_input!(input as DeriveInput);
3067 let struct_name = &input.ident;
3068
3069 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3071 Ok(attrs) => attrs,
3072 Err(e) => return e.to_compile_error().into(),
3073 };
3074
3075 let expertise = agent_attrs
3076 .expertise
3077 .unwrap_or_else(|| String::from("general AI assistant"));
3078 let output_type = agent_attrs
3079 .output
3080 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3081 let backend = agent_attrs
3082 .backend
3083 .unwrap_or_else(|| String::from("claude"));
3084 let model = agent_attrs.model;
3085 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3090 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3091 let crate_path = match found_crate {
3092 FoundCrate::Itself => {
3093 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3095 quote!(::#ident)
3096 }
3097 FoundCrate::Name(name) => {
3098 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3099 quote!(::#ident)
3100 }
3101 };
3102
3103 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3104
3105 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3107 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3108
3109 let enhanced_expertise = if is_string_output {
3111 quote! { #expertise }
3113 } else {
3114 let type_name = quote!(#output_type).to_string();
3116 quote! {
3117 {
3118 use std::sync::OnceLock;
3119 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3120
3121 EXPERTISE_CACHE.get_or_init(|| {
3122 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3124
3125 if schema.is_empty() {
3126 format!(
3128 concat!(
3129 #expertise,
3130 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3131 "Do not include any text outside the JSON object."
3132 ),
3133 #type_name
3134 )
3135 } else {
3136 format!(
3138 concat!(
3139 #expertise,
3140 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3141 ),
3142 schema
3143 )
3144 }
3145 }).as_str()
3146 }
3147 }
3148 };
3149
3150 let agent_init = match backend.as_str() {
3152 "gemini" => {
3153 if let Some(model_str) = model {
3154 quote! {
3155 use #crate_path::agent::impls::GeminiAgent;
3156 let agent = GeminiAgent::new().with_model_str(#model_str);
3157 }
3158 } else {
3159 quote! {
3160 use #crate_path::agent::impls::GeminiAgent;
3161 let agent = GeminiAgent::new();
3162 }
3163 }
3164 }
3165 "claude" => {
3166 if let Some(model_str) = model {
3167 quote! {
3168 use #crate_path::agent::impls::ClaudeCodeAgent;
3169 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3170 }
3171 } else {
3172 quote! {
3173 use #crate_path::agent::impls::ClaudeCodeAgent;
3174 let agent = ClaudeCodeAgent::new();
3175 }
3176 }
3177 }
3178 _ => {
3179 if let Some(model_str) = model {
3181 quote! {
3182 use #crate_path::agent::impls::ClaudeCodeAgent;
3183 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3184 }
3185 } else {
3186 quote! {
3187 use #crate_path::agent::impls::ClaudeCodeAgent;
3188 let agent = ClaudeCodeAgent::new();
3189 }
3190 }
3191 }
3192 };
3193
3194 let expanded = quote! {
3195 #[async_trait::async_trait]
3196 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3197 type Output = #output_type;
3198
3199 fn expertise(&self) -> &str {
3200 #enhanced_expertise
3201 }
3202
3203 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3204 #agent_init
3206
3207 let max_retries: u32 = #max_retries;
3209 let mut attempts = 0u32;
3210
3211 loop {
3212 attempts += 1;
3213
3214 let result = async {
3216 let response = agent.execute(intent.clone()).await?;
3217
3218 let json_str = #crate_path::extract_json(&response)
3220 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3221
3222 serde_json::from_str::<Self::Output>(&json_str)
3224 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3225 }.await;
3226
3227 match result {
3228 Ok(output) => return Ok(output),
3229 Err(e) if e.is_retryable() && attempts < max_retries => {
3230 log::warn!(
3232 "Agent execution failed (attempt {}/{}): {}. Retrying...",
3233 attempts,
3234 max_retries,
3235 e
3236 );
3237
3238 let delay_ms = 100 * attempts as u64;
3240 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
3241
3242 continue;
3244 }
3245 Err(e) => {
3246 if attempts > 1 {
3247 log::error!(
3248 "Agent execution failed after {} attempts: {}",
3249 attempts,
3250 e
3251 );
3252 }
3253 return Err(e);
3254 }
3255 }
3256 }
3257 }
3258
3259 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3260 #agent_init
3262 agent.is_available().await
3263 }
3264 }
3265 };
3266
3267 TokenStream::from(expanded)
3268}
3269
3270#[proc_macro_attribute]
3285pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3286 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3288 Ok(attrs) => attrs,
3289 Err(e) => return e.to_compile_error().into(),
3290 };
3291
3292 let input = parse_macro_input!(item as DeriveInput);
3294 let struct_name = &input.ident;
3295 let vis = &input.vis;
3296
3297 let expertise = agent_attrs
3298 .expertise
3299 .unwrap_or_else(|| String::from("general AI assistant"));
3300 let output_type = agent_attrs
3301 .output
3302 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3303 let backend = agent_attrs
3304 .backend
3305 .unwrap_or_else(|| String::from("claude"));
3306 let model = agent_attrs.model;
3307 let profile = agent_attrs.profile;
3308
3309 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3311 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3312
3313 let found_crate =
3315 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3316 let crate_path = match found_crate {
3317 FoundCrate::Itself => {
3318 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3319 quote!(::#ident)
3320 }
3321 FoundCrate::Name(name) => {
3322 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3323 quote!(::#ident)
3324 }
3325 };
3326
3327 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3329 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3330
3331 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3333 let type_path: syn::Type =
3335 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3336 quote! { #type_path }
3337 } else {
3338 match backend.as_str() {
3340 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3341 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3342 }
3343 };
3344
3345 let struct_def = quote! {
3347 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3348 inner: #inner_generic_ident,
3349 }
3350 };
3351
3352 let constructors = quote! {
3354 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3355 pub fn new(inner: #inner_generic_ident) -> Self {
3357 Self { inner }
3358 }
3359 }
3360 };
3361
3362 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3364 let default_impl = quote! {
3366 impl Default for #struct_name {
3367 fn default() -> Self {
3368 Self {
3369 inner: <#default_agent_type as Default>::default(),
3370 }
3371 }
3372 }
3373 };
3374 (quote! {}, default_impl)
3375 } else {
3376 let backend_constructors = generate_backend_constructors(
3378 struct_name,
3379 &backend,
3380 model.as_deref(),
3381 profile.as_deref(),
3382 &crate_path,
3383 );
3384 let default_impl = generate_default_impl(
3385 struct_name,
3386 &backend,
3387 model.as_deref(),
3388 profile.as_deref(),
3389 &crate_path,
3390 );
3391 (backend_constructors, default_impl)
3392 };
3393
3394 let enhanced_expertise = if is_string_output {
3396 quote! { #expertise }
3398 } else {
3399 let type_name = quote!(#output_type).to_string();
3401 quote! {
3402 {
3403 use std::sync::OnceLock;
3404 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3405
3406 EXPERTISE_CACHE.get_or_init(|| {
3407 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3409
3410 if schema.is_empty() {
3411 format!(
3413 concat!(
3414 #expertise,
3415 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3416 "Do not include any text outside the JSON object."
3417 ),
3418 #type_name
3419 )
3420 } else {
3421 format!(
3423 concat!(
3424 #expertise,
3425 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3426 ),
3427 schema
3428 )
3429 }
3430 }).as_str()
3431 }
3432 }
3433 };
3434
3435 let agent_impl = quote! {
3437 #[async_trait::async_trait]
3438 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3439 where
3440 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3441 {
3442 type Output = #output_type;
3443
3444 fn expertise(&self) -> &str {
3445 #enhanced_expertise
3446 }
3447
3448 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3449 let enhanced_payload = intent.prepend_text(self.expertise());
3451
3452 let response = self.inner.execute(enhanced_payload).await?;
3454
3455 let json_str = #crate_path::extract_json(&response)
3457 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3458
3459 serde_json::from_str(&json_str)
3461 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3462 }
3463
3464 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3465 self.inner.is_available().await
3466 }
3467 }
3468 };
3469
3470 let expanded = quote! {
3471 #struct_def
3472 #constructors
3473 #backend_constructors
3474 #default_impl
3475 #agent_impl
3476 };
3477
3478 TokenStream::from(expanded)
3479}