1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if attrs.skip {
76 continue;
77 }
78
79 if let Some(example) = attrs.example {
81 field_values.push(quote! {
83 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
84 });
85 } else if has_default {
86 field_values.push(quote! {
88 let default_value = serde_json::to_value(&default_instance.#field_name)
89 .unwrap_or(serde_json::Value::Null);
90 json_obj.insert(#field_name_str.to_string(), default_value);
91 });
92 } else {
93 field_values.push(quote! {
95 let value = serde_json::to_value(&self.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), value);
98 });
99 }
100 }
101
102 if has_default {
103 quote! {
104 {
105 let default_instance = Self::default();
106 let mut json_obj = serde_json::Map::new();
107 #(#field_values)*
108 let json_value = serde_json::Value::Object(json_obj);
109 let json_str = serde_json::to_string_pretty(&json_value)
110 .unwrap_or_else(|_| "{}".to_string());
111 vec![#crate_path::prompt::PromptPart::Text(json_str)]
112 }
113 }
114 } else {
115 quote! {
116 {
117 let mut json_obj = serde_json::Map::new();
118 #(#field_values)*
119 let json_value = serde_json::Value::Object(json_obj);
120 let json_str = serde_json::to_string_pretty(&json_value)
121 .unwrap_or_else(|_| "{}".to_string());
122 vec![#crate_path::prompt::PromptPart::Text(json_str)]
123 }
124 }
125 }
126}
127
128fn generate_schema_only_parts(
130 struct_name: &str,
131 struct_docs: &str,
132 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
133 crate_path: &proc_macro2::TokenStream,
134) -> proc_macro2::TokenStream {
135 let mut schema_lines = vec![];
136
137 if !struct_docs.is_empty() {
139 schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
140 } else {
141 schema_lines.push(format!("### Schema for `{}`", struct_name));
142 }
143
144 schema_lines.push("{".to_string());
145
146 for (i, field) in fields.iter().enumerate() {
148 let field_name = field.ident.as_ref().unwrap();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if attrs.skip {
153 continue;
154 }
155
156 let field_docs = extract_doc_comments(&field.attrs);
158
159 let type_str = format_type_for_schema(&field.ty);
161
162 let mut field_line = format!(" \"{}\": \"{}\"", field_name, type_str);
164
165 if !field_docs.is_empty() {
167 field_line.push_str(&format!(", // {}", field_docs));
168 }
169
170 let remaining_fields = fields
172 .iter()
173 .skip(i + 1)
174 .filter(|f| {
175 let attrs = parse_field_prompt_attrs(&f.attrs);
176 !attrs.skip
177 })
178 .count();
179
180 if remaining_fields > 0 {
181 field_line.push(',');
182 }
183
184 schema_lines.push(field_line);
185 }
186
187 schema_lines.push("}".to_string());
188
189 let schema_str = schema_lines.join("\n");
190
191 quote! {
192 vec![#crate_path::prompt::PromptPart::Text(#schema_str.to_string())]
193 }
194}
195
196fn format_type_for_schema(ty: &syn::Type) -> String {
198 match ty {
200 syn::Type::Path(type_path) => {
201 let path = &type_path.path;
202 if let Some(last_segment) = path.segments.last() {
203 let type_name = last_segment.ident.to_string();
204
205 if type_name == "Option"
207 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
208 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
209 {
210 return format!("{} | null", format_type_for_schema(inner_type));
211 }
212
213 match type_name.as_str() {
215 "String" | "str" => "string".to_string(),
216 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
217 | "u64" | "u128" | "usize" => "number".to_string(),
218 "f32" | "f64" => "number".to_string(),
219 "bool" => "boolean".to_string(),
220 "Vec" => {
221 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
222 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
223 {
224 return format!("{}[]", format_type_for_schema(inner_type));
225 }
226 "array".to_string()
227 }
228 _ => type_name.to_lowercase(),
229 }
230 } else {
231 "unknown".to_string()
232 }
233 }
234 _ => "unknown".to_string(),
235 }
236}
237
/// Outcome of reading an enum variant's `#[prompt(...)]` attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the schema listing.
    Skip,
    /// `#[prompt("text")]`: an explicit description for the variant.
    Description(String),
    /// No usable `#[prompt(...)]` attribute was found.
    None,
}
244
245fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
247 for attr in attrs {
248 if attr.path().is_ident("prompt") {
249 if let Ok(meta_list) = attr.meta.require_list() {
251 let tokens = &meta_list.tokens;
252 let tokens_str = tokens.to_string();
253 if tokens_str == "skip" {
254 return PromptAttribute::Skip;
255 }
256 }
257
258 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
260 return PromptAttribute::Description(lit_str.value());
261 }
262 }
263 }
264 PromptAttribute::None
265}
266
/// Options read from a struct field's `#[prompt(...)]` attributes.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of the generated prompt.
    skip: bool,
    /// `rename = "..."`: key used instead of the doc comment or field name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&field` to render its value.
    format_with: Option<String>,
    /// `image`: append the field's own prompt parts rather than a `key: value` text line.
    image: bool,
    /// `example = "..."`: literal value shown for this field in example-only output.
    example: Option<String>,
}
276
277fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
279 let mut result = FieldPromptAttrs::default();
280
281 for attr in attrs {
282 if attr.path().is_ident("prompt") {
283 if let Ok(meta_list) = attr.meta.require_list() {
285 if let Ok(metas) =
287 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
288 {
289 for meta in metas {
290 match meta {
291 Meta::Path(path) if path.is_ident("skip") => {
292 result.skip = true;
293 }
294 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
295 if let syn::Expr::Lit(syn::ExprLit {
296 lit: syn::Lit::Str(lit_str),
297 ..
298 }) = nv.value
299 {
300 result.rename = Some(lit_str.value());
301 }
302 }
303 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
304 if let syn::Expr::Lit(syn::ExprLit {
305 lit: syn::Lit::Str(lit_str),
306 ..
307 }) = nv.value
308 {
309 result.format_with = Some(lit_str.value());
310 }
311 }
312 Meta::Path(path) if path.is_ident("image") => {
313 result.image = true;
314 }
315 Meta::NameValue(nv) if nv.path.is_ident("example") => {
316 if let syn::Expr::Lit(syn::ExprLit {
317 lit: syn::Lit::Str(lit_str),
318 ..
319 }) = nv.value
320 {
321 result.example = Some(lit_str.value());
322 }
323 }
324 _ => {}
325 }
326 }
327 } else if meta_list.tokens.to_string() == "skip" {
328 result.skip = true;
330 } else if meta_list.tokens.to_string() == "image" {
331 result.image = true;
333 }
334 }
335 }
336 }
337
338 result
339}
340
341#[proc_macro_derive(ToPrompt, attributes(prompt))]
384pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
385 let input = parse_macro_input!(input as DeriveInput);
386
387 let found_crate =
388 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
389 let crate_path = match found_crate {
390 FoundCrate::Itself => {
391 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
393 quote!(::#ident)
394 }
395 FoundCrate::Name(name) => {
396 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
397 quote!(::#ident)
398 }
399 };
400
401 match &input.data {
403 Data::Enum(data_enum) => {
404 let enum_name = &input.ident;
406 let enum_docs = extract_doc_comments(&input.attrs);
407
408 let mut prompt_lines = Vec::new();
409
410 if !enum_docs.is_empty() {
412 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
413 } else {
414 prompt_lines.push(format!("{}:", enum_name));
415 }
416 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
418
419 for variant in &data_enum.variants {
421 let variant_name = &variant.ident;
422
423 match parse_prompt_attribute(&variant.attrs) {
425 PromptAttribute::Skip => {
426 continue;
428 }
429 PromptAttribute::Description(desc) => {
430 prompt_lines.push(format!("- {}: {}", variant_name, desc));
432 }
433 PromptAttribute::None => {
434 let variant_docs = extract_doc_comments(&variant.attrs);
436 if !variant_docs.is_empty() {
437 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
438 } else {
439 prompt_lines.push(format!("- {}", variant_name));
440 }
441 }
442 }
443 }
444
445 let prompt_string = prompt_lines.join("\n");
446 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
447
448 let mut match_arms = Vec::new();
450 for variant in &data_enum.variants {
451 let variant_name = &variant.ident;
452
453 match parse_prompt_attribute(&variant.attrs) {
455 PromptAttribute::Skip => {
456 match_arms.push(quote! {
458 Self::#variant_name => stringify!(#variant_name).to_string()
459 });
460 }
461 PromptAttribute::Description(desc) => {
462 match_arms.push(quote! {
464 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
465 });
466 }
467 PromptAttribute::None => {
468 let variant_docs = extract_doc_comments(&variant.attrs);
470 if !variant_docs.is_empty() {
471 match_arms.push(quote! {
472 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
473 });
474 } else {
475 match_arms.push(quote! {
476 Self::#variant_name => stringify!(#variant_name).to_string()
477 });
478 }
479 }
480 }
481 }
482
483 let to_prompt_impl = if match_arms.is_empty() {
484 quote! {
486 fn to_prompt(&self) -> String {
487 match *self {}
488 }
489 }
490 } else {
491 quote! {
492 fn to_prompt(&self) -> String {
493 match self {
494 #(#match_arms),*
495 }
496 }
497 }
498 };
499
500 let expanded = quote! {
501 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
502 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
503 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
504 }
505
506 #to_prompt_impl
507
508 fn prompt_schema() -> String {
509 #prompt_string.to_string()
510 }
511 }
512 };
513
514 TokenStream::from(expanded)
515 }
516 Data::Struct(data_struct) => {
517 let mut template_attr = None;
519 let mut template_file_attr = None;
520 let mut mode_attr = None;
521 let mut validate_attr = false;
522
523 for attr in &input.attrs {
524 if attr.path().is_ident("prompt") {
525 if let Ok(metas) =
527 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
528 {
529 for meta in metas {
530 match meta {
531 Meta::NameValue(nv) if nv.path.is_ident("template") => {
532 if let syn::Expr::Lit(expr_lit) = nv.value
533 && let syn::Lit::Str(lit_str) = expr_lit.lit
534 {
535 template_attr = Some(lit_str.value());
536 }
537 }
538 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
539 if let syn::Expr::Lit(expr_lit) = nv.value
540 && let syn::Lit::Str(lit_str) = expr_lit.lit
541 {
542 template_file_attr = Some(lit_str.value());
543 }
544 }
545 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
546 if let syn::Expr::Lit(expr_lit) = nv.value
547 && let syn::Lit::Str(lit_str) = expr_lit.lit
548 {
549 mode_attr = Some(lit_str.value());
550 }
551 }
552 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
553 if let syn::Expr::Lit(expr_lit) = nv.value
554 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
555 {
556 validate_attr = lit_bool.value();
557 }
558 }
559 _ => {}
560 }
561 }
562 }
563 }
564 }
565
566 if template_attr.is_some() && template_file_attr.is_some() {
568 return syn::Error::new(
569 input.ident.span(),
570 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
571 ).to_compile_error().into();
572 }
573
574 let template_str = if let Some(file_path) = template_file_attr {
576 let mut full_path = None;
580
581 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
583 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
585
586 if !is_trybuild {
587 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
589 if candidate.exists() {
590 full_path = Some(candidate);
591 }
592 } else {
593 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
599 let workspace_root = &manifest_dir[..target_pos];
600 let original_macros_dir = std::path::Path::new(workspace_root)
602 .join("crates")
603 .join("llm-toolkit-macros");
604
605 let candidate = original_macros_dir.join(&file_path);
606 if candidate.exists() {
607 full_path = Some(candidate);
608 }
609 }
610 }
611 }
612
613 if full_path.is_none() {
615 let candidate = std::path::Path::new(&file_path).to_path_buf();
616 if candidate.exists() {
617 full_path = Some(candidate);
618 }
619 }
620
621 if full_path.is_none()
624 && let Ok(current_dir) = std::env::current_dir()
625 {
626 let mut search_dir = current_dir.as_path();
627 for _ in 0..10 {
629 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
631 if macros_dir.exists() {
632 let candidate = macros_dir.join(&file_path);
633 if candidate.exists() {
634 full_path = Some(candidate);
635 break;
636 }
637 }
638 let candidate = search_dir.join(&file_path);
640 if candidate.exists() {
641 full_path = Some(candidate);
642 break;
643 }
644 if let Some(parent) = search_dir.parent() {
645 search_dir = parent;
646 } else {
647 break;
648 }
649 }
650 }
651
652 let final_path =
654 full_path.unwrap_or_else(|| std::path::Path::new(&file_path).to_path_buf());
655
656 match std::fs::read_to_string(&final_path) {
658 Ok(content) => Some(content),
659 Err(e) => {
660 return syn::Error::new(
661 input.ident.span(),
662 format!(
663 "Failed to read template file '{}': {}",
664 final_path.display(),
665 e
666 ),
667 )
668 .to_compile_error()
669 .into();
670 }
671 }
672 } else {
673 template_attr
674 };
675
676 if validate_attr && let Some(template) = &template_str {
678 let mut env = minijinja::Environment::new();
680 if let Err(e) = env.add_template("validation", template) {
681 let warning_msg =
683 format!("Template validation warning: Invalid Jinja syntax - {}", e);
684 let warning_ident = syn::Ident::new(
685 "TEMPLATE_VALIDATION_WARNING",
686 proc_macro2::Span::call_site(),
687 );
688 let _warning_tokens = quote! {
689 #[deprecated(note = #warning_msg)]
690 const #warning_ident: () = ();
691 let _ = #warning_ident;
692 };
693 eprintln!("cargo:warning={}", warning_msg);
695 }
696
697 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
699 &fields.named
700 } else {
701 panic!("Template validation is only supported for structs with named fields.");
702 };
703
704 let field_names: std::collections::HashSet<String> = fields
705 .iter()
706 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
707 .collect();
708
709 let placeholders = parse_template_placeholders_with_mode(template);
711
712 for (placeholder_name, _mode) in &placeholders {
713 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
714 let warning_msg = format!(
715 "Template validation warning: Variable '{}' used in template but not found in struct fields",
716 placeholder_name
717 );
718 eprintln!("cargo:warning={}", warning_msg);
719 }
720 }
721 }
722
723 let name = input.ident;
724 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
725
726 let struct_docs = extract_doc_comments(&input.attrs);
728
729 let is_mode_based =
731 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
732
733 let expanded = if is_mode_based || mode_attr.is_some() {
734 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
736 &fields.named
737 } else {
738 panic!(
739 "Mode-based prompt generation is only supported for structs with named fields."
740 );
741 };
742
743 let struct_name_str = name.to_string();
744
745 let has_default = input.attrs.iter().any(|attr| {
747 if attr.path().is_ident("derive")
748 && let Ok(meta_list) = attr.meta.require_list()
749 {
750 let tokens_str = meta_list.tokens.to_string();
751 tokens_str.contains("Default")
752 } else {
753 false
754 }
755 });
756
757 let schema_parts =
759 generate_schema_only_parts(&struct_name_str, &struct_docs, fields, &crate_path);
760
761 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
763
764 quote! {
765 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
766 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
767 match mode {
768 "schema_only" => #schema_parts,
769 "example_only" => #example_parts,
770 "full" | _ => {
771 let mut parts = Vec::new();
773
774 let schema_parts = #schema_parts;
776 parts.extend(schema_parts);
777
778 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
780 parts.push(#crate_path::prompt::PromptPart::Text(
781 format!("Here is an example of a valid `{}` object:", #struct_name_str)
782 ));
783
784 let example_parts = #example_parts;
786 parts.extend(example_parts);
787
788 parts
789 }
790 }
791 }
792
793 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
794 self.to_prompt_parts_with_mode("full")
795 }
796
797 fn to_prompt(&self) -> String {
798 self.to_prompt_parts()
799 .into_iter()
800 .filter_map(|part| match part {
801 #crate_path::prompt::PromptPart::Text(text) => Some(text),
802 _ => None,
803 })
804 .collect::<Vec<_>>()
805 .join("\n")
806 }
807
808 fn prompt_schema() -> String {
809 let schema_parts = #schema_parts;
810 schema_parts
811 .into_iter()
812 .filter_map(|part| match part {
813 #crate_path::prompt::PromptPart::Text(text) => Some(text),
814 _ => None,
815 })
816 .collect::<Vec<_>>()
817 .join("\n")
818 }
819 }
820 }
821 } else if let Some(template) = template_str {
822 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
825 &fields.named
826 } else {
827 panic!(
828 "Template prompt generation is only supported for structs with named fields."
829 );
830 };
831
832 let placeholders = parse_template_placeholders_with_mode(&template);
834 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
836 mode.is_some()
837 && fields
838 .iter()
839 .any(|f| f.ident.as_ref().unwrap() == field_name)
840 });
841
842 let mut image_field_parts = Vec::new();
843 for f in fields.iter() {
844 let field_name = f.ident.as_ref().unwrap();
845 let attrs = parse_field_prompt_attrs(&f.attrs);
846
847 if attrs.image {
848 image_field_parts.push(quote! {
850 parts.extend(self.#field_name.to_prompt_parts());
851 });
852 }
853 }
854
855 if has_mode_syntax {
857 let mut context_fields = Vec::new();
859 let mut modified_template = template.clone();
860
861 for (field_name, mode_opt) in &placeholders {
863 if let Some(mode) = mode_opt {
864 let unique_key = format!("{}__{}", field_name, mode);
866
867 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
869 let replacement = format!("{{{{ {} }}}}", unique_key);
870 modified_template = modified_template.replace(&pattern, &replacement);
871
872 let field_ident =
874 syn::Ident::new(field_name, proc_macro2::Span::call_site());
875
876 context_fields.push(quote! {
878 context.insert(
879 #unique_key.to_string(),
880 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
881 );
882 });
883 }
884 }
885
886 for field in fields.iter() {
888 let field_name = field.ident.as_ref().unwrap();
889 let field_name_str = field_name.to_string();
890
891 let has_mode_entry = placeholders
893 .iter()
894 .any(|(name, mode)| name == &field_name_str && mode.is_some());
895
896 if !has_mode_entry {
897 let is_primitive = match &field.ty {
900 syn::Type::Path(type_path) => {
901 if let Some(segment) = type_path.path.segments.last() {
902 let type_name = segment.ident.to_string();
903 matches!(
904 type_name.as_str(),
905 "String"
906 | "str"
907 | "i8"
908 | "i16"
909 | "i32"
910 | "i64"
911 | "i128"
912 | "isize"
913 | "u8"
914 | "u16"
915 | "u32"
916 | "u64"
917 | "u128"
918 | "usize"
919 | "f32"
920 | "f64"
921 | "bool"
922 | "char"
923 )
924 } else {
925 false
926 }
927 }
928 _ => false,
929 };
930
931 if is_primitive {
932 context_fields.push(quote! {
933 context.insert(
934 #field_name_str.to_string(),
935 minijinja::Value::from_serialize(&self.#field_name)
936 );
937 });
938 } else {
939 context_fields.push(quote! {
941 context.insert(
942 #field_name_str.to_string(),
943 minijinja::Value::from(self.#field_name.to_prompt())
944 );
945 });
946 }
947 }
948 }
949
950 quote! {
951 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
952 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
953 let mut parts = Vec::new();
954
955 #(#image_field_parts)*
957
958 let text = {
960 let mut env = minijinja::Environment::new();
961 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
962 panic!("Failed to parse template: {}", e)
963 });
964
965 let tmpl = env.get_template("prompt").unwrap();
966
967 let mut context = std::collections::HashMap::new();
968 #(#context_fields)*
969
970 tmpl.render(context).unwrap_or_else(|e| {
971 format!("Failed to render prompt: {}", e)
972 })
973 };
974
975 if !text.is_empty() {
976 parts.push(#crate_path::prompt::PromptPart::Text(text));
977 }
978
979 parts
980 }
981
982 fn to_prompt(&self) -> String {
983 let mut env = minijinja::Environment::new();
985 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
986 panic!("Failed to parse template: {}", e)
987 });
988
989 let tmpl = env.get_template("prompt").unwrap();
990
991 let mut context = std::collections::HashMap::new();
992 #(#context_fields)*
993
994 tmpl.render(context).unwrap_or_else(|e| {
995 format!("Failed to render prompt: {}", e)
996 })
997 }
998
999 fn prompt_schema() -> String {
1000 String::new() }
1002 }
1003 }
1004 } else {
1005 quote! {
1007 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1008 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1009 let mut parts = Vec::new();
1010
1011 #(#image_field_parts)*
1013
1014 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1016 format!("Failed to render prompt: {}", e)
1017 });
1018 if !text.is_empty() {
1019 parts.push(#crate_path::prompt::PromptPart::Text(text));
1020 }
1021
1022 parts
1023 }
1024
1025 fn to_prompt(&self) -> String {
1026 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1027 format!("Failed to render prompt: {}", e)
1028 })
1029 }
1030
1031 fn prompt_schema() -> String {
1032 String::new() }
1034 }
1035 }
1036 }
1037 } else {
1038 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1041 &fields.named
1042 } else {
1043 panic!(
1044 "Default prompt generation is only supported for structs with named fields."
1045 );
1046 };
1047
1048 let mut text_field_parts = Vec::new();
1050 let mut image_field_parts = Vec::new();
1051
1052 for f in fields.iter() {
1053 let field_name = f.ident.as_ref().unwrap();
1054 let attrs = parse_field_prompt_attrs(&f.attrs);
1055
1056 if attrs.skip {
1058 continue;
1059 }
1060
1061 if attrs.image {
1062 image_field_parts.push(quote! {
1064 parts.extend(self.#field_name.to_prompt_parts());
1065 });
1066 } else {
1067 let key = if let Some(rename) = attrs.rename {
1073 rename
1074 } else {
1075 let doc_comment = extract_doc_comments(&f.attrs);
1076 if !doc_comment.is_empty() {
1077 doc_comment
1078 } else {
1079 field_name.to_string()
1080 }
1081 };
1082
1083 let value_expr = if let Some(format_with) = attrs.format_with {
1085 let func_path: syn::Path =
1087 syn::parse_str(&format_with).unwrap_or_else(|_| {
1088 panic!("Invalid function path: {}", format_with)
1089 });
1090 quote! { #func_path(&self.#field_name) }
1091 } else {
1092 quote! { self.#field_name.to_prompt() }
1093 };
1094
1095 text_field_parts.push(quote! {
1096 text_parts.push(format!("{}: {}", #key, #value_expr));
1097 });
1098 }
1099 }
1100
1101 quote! {
1103 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1104 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1105 let mut parts = Vec::new();
1106
1107 #(#image_field_parts)*
1109
1110 let mut text_parts = Vec::new();
1112 #(#text_field_parts)*
1113
1114 if !text_parts.is_empty() {
1115 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1116 }
1117
1118 parts
1119 }
1120
1121 fn to_prompt(&self) -> String {
1122 let mut text_parts = Vec::new();
1123 #(#text_field_parts)*
1124 text_parts.join("\n")
1125 }
1126
1127 fn prompt_schema() -> String {
1128 String::new() }
1130 }
1131 }
1132 };
1133
1134 TokenStream::from(expanded)
1135 }
1136 Data::Union(_) => {
1137 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1138 }
1139 }
1140}
1141
1142#[derive(Debug, Clone)]
1144struct TargetInfo {
1145 name: String,
1146 template: Option<String>,
1147 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1148}
1149
/// Settings from a field's `#[prompt_for(...)]` attribute, for one target (or for all targets
/// via `#[prompt_for(skip)]`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct FieldTargetConfig {
    /// Leave the field out of this target's prompt.
    skip: bool,
    /// Key used instead of the field name.
    rename: Option<String>,
    /// Path of a function called with `&field` to render its value.
    format_with: Option<String>,
    /// Append the field's own prompt parts instead of a text line.
    image: bool,
    /// Set for configs that name a target. A field carrying any such config is rendered only
    /// for the targets it names.
    include_only: bool,
}
1159
1160fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1162 let mut configs = Vec::new();
1163
1164 for attr in attrs {
1165 if attr.path().is_ident("prompt_for")
1166 && let Ok(meta_list) = attr.meta.require_list()
1167 {
1168 if meta_list.tokens.to_string() == "skip" {
1170 let config = FieldTargetConfig {
1172 skip: true,
1173 ..Default::default()
1174 };
1175 configs.push(("*".to_string(), config));
1176 } else if let Ok(metas) =
1177 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1178 {
1179 let mut target_name = None;
1180 let mut config = FieldTargetConfig::default();
1181
1182 for meta in metas {
1183 match meta {
1184 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1185 if let syn::Expr::Lit(syn::ExprLit {
1186 lit: syn::Lit::Str(lit_str),
1187 ..
1188 }) = nv.value
1189 {
1190 target_name = Some(lit_str.value());
1191 }
1192 }
1193 Meta::Path(path) if path.is_ident("skip") => {
1194 config.skip = true;
1195 }
1196 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1197 if let syn::Expr::Lit(syn::ExprLit {
1198 lit: syn::Lit::Str(lit_str),
1199 ..
1200 }) = nv.value
1201 {
1202 config.rename = Some(lit_str.value());
1203 }
1204 }
1205 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1206 if let syn::Expr::Lit(syn::ExprLit {
1207 lit: syn::Lit::Str(lit_str),
1208 ..
1209 }) = nv.value
1210 {
1211 config.format_with = Some(lit_str.value());
1212 }
1213 }
1214 Meta::Path(path) if path.is_ident("image") => {
1215 config.image = true;
1216 }
1217 _ => {}
1218 }
1219 }
1220
1221 if let Some(name) = target_name {
1222 config.include_only = true;
1223 configs.push((name, config));
1224 }
1225 }
1226 }
1227 }
1228
1229 configs
1230}
1231
1232fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1234 let mut targets = Vec::new();
1235
1236 for attr in attrs {
1237 if attr.path().is_ident("prompt_for")
1238 && let Ok(meta_list) = attr.meta.require_list()
1239 && let Ok(metas) =
1240 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1241 {
1242 let mut target_name = None;
1243 let mut template = None;
1244
1245 for meta in metas {
1246 match meta {
1247 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1248 if let syn::Expr::Lit(syn::ExprLit {
1249 lit: syn::Lit::Str(lit_str),
1250 ..
1251 }) = nv.value
1252 {
1253 target_name = Some(lit_str.value());
1254 }
1255 }
1256 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1257 if let syn::Expr::Lit(syn::ExprLit {
1258 lit: syn::Lit::Str(lit_str),
1259 ..
1260 }) = nv.value
1261 {
1262 template = Some(lit_str.value());
1263 }
1264 }
1265 _ => {}
1266 }
1267 }
1268
1269 if let Some(name) = target_name {
1270 targets.push(TargetInfo {
1271 name,
1272 template,
1273 field_configs: std::collections::HashMap::new(),
1274 });
1275 }
1276 }
1277 }
1278
1279 targets
1280}
1281
1282#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1283pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1284 let input = parse_macro_input!(input as DeriveInput);
1285
1286 let found_crate =
1287 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1288 let crate_path = match found_crate {
1289 FoundCrate::Itself => {
1290 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1292 quote!(::#ident)
1293 }
1294 FoundCrate::Name(name) => {
1295 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1296 quote!(::#ident)
1297 }
1298 };
1299
1300 let data_struct = match &input.data {
1302 Data::Struct(data) => data,
1303 _ => {
1304 return syn::Error::new(
1305 input.ident.span(),
1306 "`#[derive(ToPromptSet)]` is only supported for structs",
1307 )
1308 .to_compile_error()
1309 .into();
1310 }
1311 };
1312
1313 let fields = match &data_struct.fields {
1314 syn::Fields::Named(fields) => &fields.named,
1315 _ => {
1316 return syn::Error::new(
1317 input.ident.span(),
1318 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1319 )
1320 .to_compile_error()
1321 .into();
1322 }
1323 };
1324
1325 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1327
1328 for field in fields.iter() {
1330 let field_name = field.ident.as_ref().unwrap().to_string();
1331 let field_configs = parse_prompt_for_attrs(&field.attrs);
1332
1333 for (target_name, config) in field_configs {
1334 if target_name == "*" {
1335 for target in &mut targets {
1337 target
1338 .field_configs
1339 .entry(field_name.clone())
1340 .or_insert_with(FieldTargetConfig::default)
1341 .skip = config.skip;
1342 }
1343 } else {
1344 let target_exists = targets.iter().any(|t| t.name == target_name);
1346 if !target_exists {
1347 targets.push(TargetInfo {
1349 name: target_name.clone(),
1350 template: None,
1351 field_configs: std::collections::HashMap::new(),
1352 });
1353 }
1354
1355 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1356
1357 target.field_configs.insert(field_name.clone(), config);
1358 }
1359 }
1360 }
1361
1362 let mut match_arms = Vec::new();
1364
1365 for target in &targets {
1366 let target_name = &target.name;
1367
1368 if let Some(template_str) = &target.template {
1369 let mut image_parts = Vec::new();
1371
1372 for field in fields.iter() {
1373 let field_name = field.ident.as_ref().unwrap();
1374 let field_name_str = field_name.to_string();
1375
1376 if let Some(config) = target.field_configs.get(&field_name_str)
1377 && config.image
1378 {
1379 image_parts.push(quote! {
1380 parts.extend(self.#field_name.to_prompt_parts());
1381 });
1382 }
1383 }
1384
1385 match_arms.push(quote! {
1386 #target_name => {
1387 let mut parts = Vec::new();
1388
1389 #(#image_parts)*
1390
1391 let text = #crate_path::prompt::render_prompt(#template_str, self)
1392 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1393 target: #target_name.to_string(),
1394 source: e,
1395 })?;
1396
1397 if !text.is_empty() {
1398 parts.push(#crate_path::prompt::PromptPart::Text(text));
1399 }
1400
1401 Ok(parts)
1402 }
1403 });
1404 } else {
1405 let mut text_field_parts = Vec::new();
1407 let mut image_field_parts = Vec::new();
1408
1409 for field in fields.iter() {
1410 let field_name = field.ident.as_ref().unwrap();
1411 let field_name_str = field_name.to_string();
1412
1413 let config = target.field_configs.get(&field_name_str);
1415
1416 if let Some(cfg) = config
1418 && cfg.skip
1419 {
1420 continue;
1421 }
1422
1423 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1427 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1428 .iter()
1429 .any(|(name, _)| name != "*");
1430
1431 if has_any_target_specific_config && !is_explicitly_for_this_target {
1432 continue;
1433 }
1434
1435 if let Some(cfg) = config {
1436 if cfg.image {
1437 image_field_parts.push(quote! {
1438 parts.extend(self.#field_name.to_prompt_parts());
1439 });
1440 } else {
1441 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1442
1443 let value_expr = if let Some(format_with) = &cfg.format_with {
1444 match syn::parse_str::<syn::Path>(format_with) {
1446 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1447 Err(_) => {
1448 let error_msg = format!(
1450 "Invalid function path in format_with: '{}'",
1451 format_with
1452 );
1453 quote! {
1454 compile_error!(#error_msg);
1455 String::new()
1456 }
1457 }
1458 }
1459 } else {
1460 quote! { self.#field_name.to_prompt() }
1461 };
1462
1463 text_field_parts.push(quote! {
1464 text_parts.push(format!("{}: {}", #key, #value_expr));
1465 });
1466 }
1467 } else {
1468 text_field_parts.push(quote! {
1470 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1471 });
1472 }
1473 }
1474
1475 match_arms.push(quote! {
1476 #target_name => {
1477 let mut parts = Vec::new();
1478
1479 #(#image_field_parts)*
1480
1481 let mut text_parts = Vec::new();
1482 #(#text_field_parts)*
1483
1484 if !text_parts.is_empty() {
1485 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1486 }
1487
1488 Ok(parts)
1489 }
1490 });
1491 }
1492 }
1493
1494 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1496
1497 match_arms.push(quote! {
1499 _ => {
1500 let available = vec![#(#target_names.to_string()),*];
1501 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1502 target: target.to_string(),
1503 available,
1504 })
1505 }
1506 });
1507
1508 let struct_name = &input.ident;
1509 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1510
1511 let expanded = quote! {
1512 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1513 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1514 match target {
1515 #(#match_arms)*
1516 }
1517 }
1518 }
1519 };
1520
1521 TokenStream::from(expanded)
1522}
1523
1524struct TypeList {
1526 types: Punctuated<syn::Type, Token![,]>,
1527}
1528
1529impl Parse for TypeList {
1530 fn parse(input: ParseStream) -> syn::Result<Self> {
1531 Ok(TypeList {
1532 types: Punctuated::parse_terminated(input)?,
1533 })
1534 }
1535}
1536
1537#[proc_macro]
1561pub fn examples_section(input: TokenStream) -> TokenStream {
1562 let input = parse_macro_input!(input as TypeList);
1563
1564 let found_crate =
1565 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1566 let _crate_path = match found_crate {
1567 FoundCrate::Itself => quote!(crate),
1568 FoundCrate::Name(name) => {
1569 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1570 quote!(::#ident)
1571 }
1572 };
1573
1574 let mut type_sections = Vec::new();
1576
1577 for ty in input.types.iter() {
1578 let type_name_str = quote!(#ty).to_string();
1580
1581 type_sections.push(quote! {
1583 {
1584 let type_name = #type_name_str;
1585 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1586 format!("---\n#### `{}`\n{}", type_name, json_example)
1587 }
1588 });
1589 }
1590
1591 let expanded = quote! {
1593 {
1594 let mut sections = Vec::new();
1595 sections.push("---".to_string());
1596 sections.push("### Examples".to_string());
1597 sections.push("".to_string());
1598 sections.push("Here are examples of the data structures you should use.".to_string());
1599 sections.push("".to_string());
1600
1601 #(sections.push(#type_sections);)*
1602
1603 sections.push("---".to_string());
1604
1605 sections.join("\n")
1606 }
1607 };
1608
1609 TokenStream::from(expanded)
1610}
1611
1612fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1614 for attr in attrs {
1615 if attr.path().is_ident("prompt_for")
1616 && let Ok(meta_list) = attr.meta.require_list()
1617 && let Ok(metas) =
1618 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1619 {
1620 let mut target_type = None;
1621 let mut template = None;
1622
1623 for meta in metas {
1624 match meta {
1625 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1626 if let syn::Expr::Lit(syn::ExprLit {
1627 lit: syn::Lit::Str(lit_str),
1628 ..
1629 }) = nv.value
1630 {
1631 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1633 }
1634 }
1635 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1636 if let syn::Expr::Lit(syn::ExprLit {
1637 lit: syn::Lit::Str(lit_str),
1638 ..
1639 }) = nv.value
1640 {
1641 template = Some(lit_str.value());
1642 }
1643 }
1644 _ => {}
1645 }
1646 }
1647
1648 if let (Some(target), Some(tmpl)) = (target_type, template) {
1649 return (target, tmpl);
1650 }
1651 }
1652 }
1653
1654 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1655}
1656
1657#[proc_macro_attribute]
1691pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1692 let input = parse_macro_input!(item as DeriveInput);
1693
1694 let found_crate =
1695 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1696 let crate_path = match found_crate {
1697 FoundCrate::Itself => {
1698 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1700 quote!(::#ident)
1701 }
1702 FoundCrate::Name(name) => {
1703 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1704 quote!(::#ident)
1705 }
1706 };
1707
1708 let enum_data = match &input.data {
1710 Data::Enum(data) => data,
1711 _ => {
1712 return syn::Error::new(
1713 input.ident.span(),
1714 "`#[define_intent]` can only be applied to enums",
1715 )
1716 .to_compile_error()
1717 .into();
1718 }
1719 };
1720
1721 let mut prompt_template = None;
1723 let mut extractor_tag = None;
1724 let mut mode = None;
1725
1726 for attr in &input.attrs {
1727 if attr.path().is_ident("intent")
1728 && let Ok(metas) =
1729 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1730 {
1731 for meta in metas {
1732 match meta {
1733 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1734 if let syn::Expr::Lit(syn::ExprLit {
1735 lit: syn::Lit::Str(lit_str),
1736 ..
1737 }) = nv.value
1738 {
1739 prompt_template = Some(lit_str.value());
1740 }
1741 }
1742 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1743 if let syn::Expr::Lit(syn::ExprLit {
1744 lit: syn::Lit::Str(lit_str),
1745 ..
1746 }) = nv.value
1747 {
1748 extractor_tag = Some(lit_str.value());
1749 }
1750 }
1751 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1752 if let syn::Expr::Lit(syn::ExprLit {
1753 lit: syn::Lit::Str(lit_str),
1754 ..
1755 }) = nv.value
1756 {
1757 mode = Some(lit_str.value());
1758 }
1759 }
1760 _ => {}
1761 }
1762 }
1763 }
1764 }
1765
1766 let mode = mode.unwrap_or_else(|| "single".to_string());
1768
1769 if mode != "single" && mode != "multi_tag" {
1771 return syn::Error::new(
1772 input.ident.span(),
1773 "`mode` must be either \"single\" or \"multi_tag\"",
1774 )
1775 .to_compile_error()
1776 .into();
1777 }
1778
1779 let prompt_template = match prompt_template {
1781 Some(p) => p,
1782 None => {
1783 return syn::Error::new(
1784 input.ident.span(),
1785 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1786 )
1787 .to_compile_error()
1788 .into();
1789 }
1790 };
1791
1792 if mode == "multi_tag" {
1794 let enum_name = &input.ident;
1795 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1796 return generate_multi_tag_output(
1797 &input,
1798 enum_name,
1799 enum_data,
1800 prompt_template,
1801 actions_doc,
1802 );
1803 }
1804
1805 let extractor_tag = match extractor_tag {
1807 Some(t) => t,
1808 None => {
1809 return syn::Error::new(
1810 input.ident.span(),
1811 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1812 )
1813 .to_compile_error()
1814 .into();
1815 }
1816 };
1817
1818 let enum_name = &input.ident;
1820 let enum_docs = extract_doc_comments(&input.attrs);
1821
1822 let mut intents_doc_lines = Vec::new();
1823
1824 if !enum_docs.is_empty() {
1826 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1827 } else {
1828 intents_doc_lines.push(format!("{}:", enum_name));
1829 }
1830 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
1832
1833 for variant in &enum_data.variants {
1835 let variant_name = &variant.ident;
1836 let variant_docs = extract_doc_comments(&variant.attrs);
1837
1838 if !variant_docs.is_empty() {
1839 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1840 } else {
1841 intents_doc_lines.push(format!("- {}", variant_name));
1842 }
1843 }
1844
1845 let intents_doc_str = intents_doc_lines.join("\n");
1846
1847 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
1849 let user_variables: Vec<String> = placeholders
1850 .iter()
1851 .filter_map(|(name, _)| {
1852 if name != "intents_doc" {
1853 Some(name.clone())
1854 } else {
1855 None
1856 }
1857 })
1858 .collect();
1859
1860 let enum_name_str = enum_name.to_string();
1862 let snake_case_name = to_snake_case(&enum_name_str);
1863 let function_name = syn::Ident::new(
1864 &format!("build_{}_prompt", snake_case_name),
1865 proc_macro2::Span::call_site(),
1866 );
1867
1868 let function_params: Vec<proc_macro2::TokenStream> = user_variables
1870 .iter()
1871 .map(|var| {
1872 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1873 quote! { #ident: &str }
1874 })
1875 .collect();
1876
1877 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1879 .iter()
1880 .map(|var| {
1881 let var_str = var.clone();
1882 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1883 quote! {
1884 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1885 }
1886 })
1887 .collect();
1888
1889 let converted_template = prompt_template.clone();
1891
1892 let extractor_name = syn::Ident::new(
1894 &format!("{}Extractor", enum_name),
1895 proc_macro2::Span::call_site(),
1896 );
1897
1898 let filtered_attrs: Vec<_> = input
1900 .attrs
1901 .iter()
1902 .filter(|attr| !attr.path().is_ident("intent"))
1903 .collect();
1904
1905 let vis = &input.vis;
1907 let generics = &input.generics;
1908 let variants = &enum_data.variants;
1909 let enum_output = quote! {
1910 #(#filtered_attrs)*
1911 #vis enum #enum_name #generics {
1912 #variants
1913 }
1914 };
1915
1916 let expanded = quote! {
1918 #enum_output
1920
1921 pub fn #function_name(#(#function_params),*) -> String {
1923 let mut env = minijinja::Environment::new();
1924 env.add_template("prompt", #converted_template)
1925 .expect("Failed to parse intent prompt template");
1926
1927 let tmpl = env.get_template("prompt").unwrap();
1928
1929 let mut __template_context = std::collections::HashMap::new();
1930
1931 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
1933
1934 #(#context_insertions)*
1936
1937 tmpl.render(&__template_context)
1938 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
1939 }
1940
1941 pub struct #extractor_name;
1943
1944 impl #extractor_name {
1945 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
1946 }
1947
1948 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
1949 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
1950 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
1952 }
1953 }
1954 };
1955
1956 TokenStream::from(expanded)
1957}
1958
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when it is not the first
/// character and the previous character was not uppercase. Runs of capitals therefore stay
/// together (`"HTTPServer"` becomes `"httpserver"`). Each uppercase character is replaced by
/// the first character of its lowercase mapping; every other character is copied unchanged.
fn to_snake_case(s: &str) -> String {
    let (snake, _) = s.chars().enumerate().fold(
        (String::new(), false),
        |(mut acc, after_upper), (idx, c)| {
            if !c.is_uppercase() {
                acc.push(c);
                return (acc, false);
            }
            if idx != 0 && !after_upper {
                acc.push('_');
            }
            acc.extend(c.to_lowercase().take(1));
            (acc, true)
        },
    );
    snake
}
1979
/// Variant-level options parsed from `#[action(...)]`.
#[derive(Default, Debug)]
struct ActionAttrs {
    /// Tag name from `tag = "..."`; `None` when the variant declares no tag.
    tag: Option<String>,
}
1985
1986fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
1987 let mut result = ActionAttrs::default();
1988
1989 for attr in attrs {
1990 if attr.path().is_ident("action")
1991 && let Ok(meta_list) = attr.meta.require_list()
1992 && let Ok(metas) =
1993 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1994 {
1995 for meta in metas {
1996 if let Meta::NameValue(nv) = meta
1997 && nv.path.is_ident("tag")
1998 && let syn::Expr::Lit(syn::ExprLit {
1999 lit: syn::Lit::Str(lit_str),
2000 ..
2001 }) = nv.value
2002 {
2003 result.tag = Some(lit_str.value());
2004 }
2005 }
2006 }
2007 }
2008
2009 result
2010}
2011
/// Field-level flags parsed from `#[action(attribute)]` and `#[action(inner_text)]`.
#[derive(Default, Debug)]
struct FieldActionAttrs {
    /// Set by `#[action(attribute)]`: the field maps to an XML attribute of the same name.
    is_attribute: bool,
    /// Set by `#[action(inner_text)]`: the field receives the element's text content.
    is_inner_text: bool,
}
2018
2019fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2020 let mut result = FieldActionAttrs::default();
2021
2022 for attr in attrs {
2023 if attr.path().is_ident("action")
2024 && let Ok(meta_list) = attr.meta.require_list()
2025 {
2026 let tokens_str = meta_list.tokens.to_string();
2027 if tokens_str == "attribute" {
2028 result.is_attribute = true;
2029 } else if tokens_str == "inner_text" {
2030 result.is_inner_text = true;
2031 }
2032 }
2033 }
2034
2035 result
2036}
2037
2038fn generate_multi_tag_actions_doc(
2040 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2041) -> String {
2042 let mut doc_lines = Vec::new();
2043
2044 for variant in variants {
2045 let action_attrs = parse_action_attrs(&variant.attrs);
2046
2047 if let Some(tag) = action_attrs.tag {
2048 let variant_docs = extract_doc_comments(&variant.attrs);
2049
2050 match &variant.fields {
2051 syn::Fields::Unit => {
2052 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2054 }
2055 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2056 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2058 }
2059 syn::Fields::Named(fields) => {
2060 let mut attrs_str = Vec::new();
2062 let mut has_inner_text = false;
2063
2064 for field in &fields.named {
2065 let field_name = field.ident.as_ref().unwrap();
2066 let field_attrs = parse_field_action_attrs(&field.attrs);
2067
2068 if field_attrs.is_attribute {
2069 attrs_str.push(format!("{}=\"...\"", field_name));
2070 } else if field_attrs.is_inner_text {
2071 has_inner_text = true;
2072 }
2073 }
2074
2075 let attrs_part = if !attrs_str.is_empty() {
2076 format!(" {}", attrs_str.join(" "))
2077 } else {
2078 String::new()
2079 };
2080
2081 if has_inner_text {
2082 doc_lines.push(format!(
2083 "- `<{}{}>...</{}>`: {}",
2084 tag, attrs_part, tag, variant_docs
2085 ));
2086 } else if !attrs_str.is_empty() {
2087 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2088 } else {
2089 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2090 }
2091
2092 for field in &fields.named {
2094 let field_name = field.ident.as_ref().unwrap();
2095 let field_attrs = parse_field_action_attrs(&field.attrs);
2096 let field_docs = extract_doc_comments(&field.attrs);
2097
2098 if field_attrs.is_attribute {
2099 doc_lines
2100 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2101 } else if field_attrs.is_inner_text {
2102 doc_lines
2103 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2104 }
2105 }
2106 }
2107 _ => {
2108 }
2110 }
2111 }
2112 }
2113
2114 doc_lines.join("\n")
2115}
2116
2117fn generate_tags_regex(
2119 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2120) -> String {
2121 let mut tag_names = Vec::new();
2122
2123 for variant in variants {
2124 let action_attrs = parse_action_attrs(&variant.attrs);
2125 if let Some(tag) = action_attrs.tag {
2126 tag_names.push(tag);
2127 }
2128 }
2129
2130 if tag_names.is_empty() {
2131 return String::new();
2132 }
2133
2134 let tags_pattern = tag_names.join("|");
2135 format!(
2138 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2139 tags_pattern, tags_pattern, tags_pattern
2140 )
2141}
2142
2143fn generate_multi_tag_output(
2145 input: &DeriveInput,
2146 enum_name: &syn::Ident,
2147 enum_data: &syn::DataEnum,
2148 prompt_template: String,
2149 actions_doc: String,
2150) -> TokenStream {
2151 let found_crate =
2152 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2153 let crate_path = match found_crate {
2154 FoundCrate::Itself => {
2155 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2157 quote!(::#ident)
2158 }
2159 FoundCrate::Name(name) => {
2160 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2161 quote!(::#ident)
2162 }
2163 };
2164
2165 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2167 let user_variables: Vec<String> = placeholders
2168 .iter()
2169 .filter_map(|(name, _)| {
2170 if name != "actions_doc" {
2171 Some(name.clone())
2172 } else {
2173 None
2174 }
2175 })
2176 .collect();
2177
2178 let enum_name_str = enum_name.to_string();
2180 let snake_case_name = to_snake_case(&enum_name_str);
2181 let function_name = syn::Ident::new(
2182 &format!("build_{}_prompt", snake_case_name),
2183 proc_macro2::Span::call_site(),
2184 );
2185
2186 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2188 .iter()
2189 .map(|var| {
2190 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2191 quote! { #ident: &str }
2192 })
2193 .collect();
2194
2195 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2197 .iter()
2198 .map(|var| {
2199 let var_str = var.clone();
2200 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2201 quote! {
2202 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2203 }
2204 })
2205 .collect();
2206
2207 let extractor_name = syn::Ident::new(
2209 &format!("{}Extractor", enum_name),
2210 proc_macro2::Span::call_site(),
2211 );
2212
2213 let filtered_attrs: Vec<_> = input
2215 .attrs
2216 .iter()
2217 .filter(|attr| !attr.path().is_ident("intent"))
2218 .collect();
2219
2220 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2222 .variants
2223 .iter()
2224 .map(|variant| {
2225 let variant_name = &variant.ident;
2226 let variant_attrs: Vec<_> = variant
2227 .attrs
2228 .iter()
2229 .filter(|attr| !attr.path().is_ident("action"))
2230 .collect();
2231 let fields = &variant.fields;
2232
2233 let filtered_fields = match fields {
2235 syn::Fields::Named(named_fields) => {
2236 let filtered: Vec<_> = named_fields
2237 .named
2238 .iter()
2239 .map(|field| {
2240 let field_name = &field.ident;
2241 let field_type = &field.ty;
2242 let field_vis = &field.vis;
2243 let filtered_attrs: Vec<_> = field
2244 .attrs
2245 .iter()
2246 .filter(|attr| !attr.path().is_ident("action"))
2247 .collect();
2248 quote! {
2249 #(#filtered_attrs)*
2250 #field_vis #field_name: #field_type
2251 }
2252 })
2253 .collect();
2254 quote! { { #(#filtered,)* } }
2255 }
2256 syn::Fields::Unnamed(unnamed_fields) => {
2257 let types: Vec<_> = unnamed_fields
2258 .unnamed
2259 .iter()
2260 .map(|field| {
2261 let field_type = &field.ty;
2262 quote! { #field_type }
2263 })
2264 .collect();
2265 quote! { (#(#types),*) }
2266 }
2267 syn::Fields::Unit => quote! {},
2268 };
2269
2270 quote! {
2271 #(#variant_attrs)*
2272 #variant_name #filtered_fields
2273 }
2274 })
2275 .collect();
2276
2277 let vis = &input.vis;
2278 let generics = &input.generics;
2279
2280 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2282
2283 let tags_regex = generate_tags_regex(&enum_data.variants);
2285
2286 let expanded = quote! {
2287 #(#filtered_attrs)*
2289 #vis enum #enum_name #generics {
2290 #(#filtered_variants),*
2291 }
2292
2293 pub fn #function_name(#(#function_params),*) -> String {
2295 let mut env = minijinja::Environment::new();
2296 env.add_template("prompt", #prompt_template)
2297 .expect("Failed to parse intent prompt template");
2298
2299 let tmpl = env.get_template("prompt").unwrap();
2300
2301 let mut __template_context = std::collections::HashMap::new();
2302
2303 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2305
2306 #(#context_insertions)*
2308
2309 tmpl.render(&__template_context)
2310 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2311 }
2312
2313 pub struct #extractor_name;
2315
2316 impl #extractor_name {
2317 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2318 use ::quick_xml::events::Event;
2319 use ::quick_xml::Reader;
2320
2321 let mut actions = Vec::new();
2322 let mut reader = Reader::from_str(text);
2323 reader.config_mut().trim_text(true);
2324
2325 let mut buf = Vec::new();
2326
2327 loop {
2328 match reader.read_event_into(&mut buf) {
2329 Ok(Event::Start(e)) => {
2330 let owned_e = e.into_owned();
2331 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2332 let is_empty = false;
2333
2334 #parsing_arms
2335 }
2336 Ok(Event::Empty(e)) => {
2337 let owned_e = e.into_owned();
2338 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2339 let is_empty = true;
2340
2341 #parsing_arms
2342 }
2343 Ok(Event::Eof) => break,
2344 Err(_) => {
2345 break;
2347 }
2348 _ => {}
2349 }
2350 buf.clear();
2351 }
2352
2353 actions.into_iter().next()
2354 }
2355
2356 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2357 use ::quick_xml::events::Event;
2358 use ::quick_xml::Reader;
2359
2360 let mut actions = Vec::new();
2361 let mut reader = Reader::from_str(text);
2362 reader.config_mut().trim_text(true);
2363
2364 let mut buf = Vec::new();
2365
2366 loop {
2367 match reader.read_event_into(&mut buf) {
2368 Ok(Event::Start(e)) => {
2369 let owned_e = e.into_owned();
2370 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2371 let is_empty = false;
2372
2373 #parsing_arms
2374 }
2375 Ok(Event::Empty(e)) => {
2376 let owned_e = e.into_owned();
2377 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2378 let is_empty = true;
2379
2380 #parsing_arms
2381 }
2382 Ok(Event::Eof) => break,
2383 Err(_) => {
2384 break;
2386 }
2387 _ => {}
2388 }
2389 buf.clear();
2390 }
2391
2392 Ok(actions)
2393 }
2394
2395 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2396 where
2397 F: FnMut(#enum_name) -> String,
2398 {
2399 use ::regex::Regex;
2400
2401 let regex_pattern = #tags_regex;
2402 if regex_pattern.is_empty() {
2403 return text.to_string();
2404 }
2405
2406 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2407 panic!("Failed to compile regex for action tags: {}", e);
2408 });
2409
2410 re.replace_all(text, |caps: &::regex::Captures| {
2411 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2412
2413 if let Some(action) = self.parse_single_action(matched) {
2415 transformer(action)
2416 } else {
2417 matched.to_string()
2419 }
2420 }).to_string()
2421 }
2422
2423 pub fn strip_actions(&self, text: &str) -> String {
2424 self.transform_actions(text, |_| String::new())
2425 }
2426 }
2427 };
2428
2429 TokenStream::from(expanded)
2430}
2431
2432fn generate_parsing_arms(
2434 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2435 enum_name: &syn::Ident,
2436) -> proc_macro2::TokenStream {
2437 let mut arms = Vec::new();
2438
2439 for variant in variants {
2440 let variant_name = &variant.ident;
2441 let action_attrs = parse_action_attrs(&variant.attrs);
2442
2443 if let Some(tag) = action_attrs.tag {
2444 match &variant.fields {
2445 syn::Fields::Unit => {
2446 arms.push(quote! {
2448 if &tag_name == #tag {
2449 actions.push(#enum_name::#variant_name);
2450 }
2451 });
2452 }
2453 syn::Fields::Unnamed(_fields) => {
2454 arms.push(quote! {
2456 if &tag_name == #tag && !is_empty {
2457 match reader.read_text(owned_e.name()) {
2459 Ok(text) => {
2460 actions.push(#enum_name::#variant_name(text.to_string()));
2461 }
2462 Err(_) => {
2463 actions.push(#enum_name::#variant_name(String::new()));
2465 }
2466 }
2467 }
2468 });
2469 }
2470 syn::Fields::Named(fields) => {
2471 let mut field_names = Vec::new();
2473 let mut has_inner_text_field = None;
2474
2475 for field in &fields.named {
2476 let field_name = field.ident.as_ref().unwrap();
2477 let field_attrs = parse_field_action_attrs(&field.attrs);
2478
2479 if field_attrs.is_attribute {
2480 field_names.push(field_name.clone());
2481 } else if field_attrs.is_inner_text {
2482 has_inner_text_field = Some(field_name.clone());
2483 }
2484 }
2485
2486 if let Some(inner_text_field) = has_inner_text_field {
2487 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2490 quote! {
2491 let mut #field_name = String::new();
2492 for attr in owned_e.attributes() {
2493 if let Ok(attr) = attr {
2494 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2495 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2496 break;
2497 }
2498 }
2499 }
2500 }
2501 }).collect();
2502
2503 arms.push(quote! {
2504 if &tag_name == #tag {
2505 #(#attr_extractions)*
2506
2507 if is_empty {
2509 let #inner_text_field = String::new();
2510 actions.push(#enum_name::#variant_name {
2511 #(#field_names,)*
2512 #inner_text_field,
2513 });
2514 } else {
2515 match reader.read_text(owned_e.name()) {
2517 Ok(text) => {
2518 let #inner_text_field = text.to_string();
2519 actions.push(#enum_name::#variant_name {
2520 #(#field_names,)*
2521 #inner_text_field,
2522 });
2523 }
2524 Err(_) => {
2525 let #inner_text_field = String::new();
2527 actions.push(#enum_name::#variant_name {
2528 #(#field_names,)*
2529 #inner_text_field,
2530 });
2531 }
2532 }
2533 }
2534 }
2535 });
2536 } else {
2537 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2539 quote! {
2540 let mut #field_name = String::new();
2541 for attr in owned_e.attributes() {
2542 if let Ok(attr) = attr {
2543 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2544 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2545 break;
2546 }
2547 }
2548 }
2549 }
2550 }).collect();
2551
2552 arms.push(quote! {
2553 if &tag_name == #tag {
2554 #(#attr_extractions)*
2555 actions.push(#enum_name::#variant_name {
2556 #(#field_names),*
2557 });
2558 }
2559 });
2560 }
2561 }
2562 }
2563 }
2564 }
2565
2566 quote! {
2567 #(#arms)*
2568 }
2569}
2570
2571#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2573pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2574 let input = parse_macro_input!(input as DeriveInput);
2575
2576 let found_crate =
2577 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2578 let crate_path = match found_crate {
2579 FoundCrate::Itself => {
2580 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2582 quote!(::#ident)
2583 }
2584 FoundCrate::Name(name) => {
2585 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2586 quote!(::#ident)
2587 }
2588 };
2589
2590 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2592
2593 let struct_name = &input.ident;
2594 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2595
2596 let placeholders = parse_template_placeholders_with_mode(&template);
2598
2599 let mut converted_template = template.clone();
2601 let mut context_fields = Vec::new();
2602
2603 let fields = match &input.data {
2605 Data::Struct(data_struct) => match &data_struct.fields {
2606 syn::Fields::Named(fields) => &fields.named,
2607 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2608 },
2609 _ => panic!("ToPromptFor is only supported for structs"),
2610 };
2611
2612 let has_mode_support = input.attrs.iter().any(|attr| {
2614 if attr.path().is_ident("prompt")
2615 && let Ok(metas) =
2616 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2617 {
2618 for meta in metas {
2619 if let Meta::NameValue(nv) = meta
2620 && nv.path.is_ident("mode")
2621 {
2622 return true;
2623 }
2624 }
2625 }
2626 false
2627 });
2628
2629 for (placeholder_name, mode_opt) in &placeholders {
2631 if placeholder_name == "self" {
2632 if let Some(specific_mode) = mode_opt {
2633 let unique_key = format!("self__{}", specific_mode);
2635
2636 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2638 let replacement = format!("{{{{ {} }}}}", unique_key);
2639 converted_template = converted_template.replace(&pattern, &replacement);
2640
2641 context_fields.push(quote! {
2643 context.insert(
2644 #unique_key.to_string(),
2645 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2646 );
2647 });
2648 } else {
2649 if has_mode_support {
2652 context_fields.push(quote! {
2654 context.insert(
2655 "self".to_string(),
2656 minijinja::Value::from(self.to_prompt_with_mode(mode))
2657 );
2658 });
2659 } else {
2660 context_fields.push(quote! {
2662 context.insert(
2663 "self".to_string(),
2664 minijinja::Value::from(self.to_prompt())
2665 );
2666 });
2667 }
2668 }
2669 } else {
2670 let field_exists = fields.iter().any(|f| {
2673 f.ident
2674 .as_ref()
2675 .is_some_and(|ident| ident == placeholder_name)
2676 });
2677
2678 if field_exists {
2679 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2680
2681 context_fields.push(quote! {
2685 context.insert(
2686 #placeholder_name.to_string(),
2687 minijinja::Value::from_serialize(&self.#field_ident)
2688 );
2689 });
2690 }
2691 }
2693 }
2694
2695 let expanded = quote! {
2696 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2697 where
2698 #target_type: serde::Serialize,
2699 {
2700 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2701 let mut env = minijinja::Environment::new();
2703 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2704 panic!("Failed to parse template: {}", e)
2705 });
2706
2707 let tmpl = env.get_template("prompt").unwrap();
2708
2709 let mut context = std::collections::HashMap::new();
2711 context.insert(
2713 "self".to_string(),
2714 minijinja::Value::from_serialize(self)
2715 );
2716 context.insert(
2718 "target".to_string(),
2719 minijinja::Value::from_serialize(target)
2720 );
2721 #(#context_fields)*
2722
2723 tmpl.render(context).unwrap_or_else(|e| {
2725 format!("Failed to render prompt: {}", e)
2726 })
2727 }
2728 }
2729 };
2730
2731 TokenStream::from(expanded)
2732}
2733
2734struct AgentAttrs {
2740 expertise: Option<String>,
2741 output: Option<syn::Type>,
2742 backend: Option<String>,
2743 model: Option<String>,
2744}
2745
2746impl Parse for AgentAttrs {
2747 fn parse(input: ParseStream) -> syn::Result<Self> {
2748 let mut expertise = None;
2749 let mut output = None;
2750 let mut backend = None;
2751 let mut model = None;
2752
2753 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2754
2755 for meta in pairs {
2756 match meta {
2757 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2758 if let syn::Expr::Lit(syn::ExprLit {
2759 lit: syn::Lit::Str(lit_str),
2760 ..
2761 }) = &nv.value
2762 {
2763 expertise = Some(lit_str.value());
2764 }
2765 }
2766 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2767 if let syn::Expr::Lit(syn::ExprLit {
2768 lit: syn::Lit::Str(lit_str),
2769 ..
2770 }) = &nv.value
2771 {
2772 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2773 output = Some(ty);
2774 }
2775 }
2776 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2777 if let syn::Expr::Lit(syn::ExprLit {
2778 lit: syn::Lit::Str(lit_str),
2779 ..
2780 }) = &nv.value
2781 {
2782 backend = Some(lit_str.value());
2783 }
2784 }
2785 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2786 if let syn::Expr::Lit(syn::ExprLit {
2787 lit: syn::Lit::Str(lit_str),
2788 ..
2789 }) = &nv.value
2790 {
2791 model = Some(lit_str.value());
2792 }
2793 }
2794 _ => {}
2795 }
2796 }
2797
2798 Ok(AgentAttrs {
2799 expertise,
2800 output,
2801 backend,
2802 model,
2803 })
2804 }
2805}
2806
2807fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
2809 for attr in attrs {
2810 if attr.path().is_ident("agent") {
2811 return attr.parse_args::<AgentAttrs>();
2812 }
2813 }
2814
2815 Ok(AgentAttrs {
2816 expertise: None,
2817 output: None,
2818 backend: None,
2819 model: None,
2820 })
2821}
2822
2823#[proc_macro_derive(Agent, attributes(agent))]
2832pub fn derive_agent(input: TokenStream) -> TokenStream {
2833 let input = parse_macro_input!(input as DeriveInput);
2834 let struct_name = &input.ident;
2835
2836 let agent_attrs = match parse_agent_attrs(&input.attrs) {
2838 Ok(attrs) => attrs,
2839 Err(e) => return e.to_compile_error().into(),
2840 };
2841
2842 let expertise = agent_attrs
2843 .expertise
2844 .unwrap_or_else(|| String::from("general AI assistant"));
2845 let output_type = agent_attrs
2846 .output
2847 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
2848 let backend = agent_attrs
2849 .backend
2850 .unwrap_or_else(|| String::from("claude"));
2851 let model = agent_attrs.model;
2852
2853 let found_crate =
2855 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2856 let crate_path = match found_crate {
2857 FoundCrate::Itself => {
2858 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2860 quote!(::#ident)
2861 }
2862 FoundCrate::Name(name) => {
2863 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2864 quote!(::#ident)
2865 }
2866 };
2867
2868 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2869
2870 let agent_init = match backend.as_str() {
2872 "gemini" => {
2873 if let Some(model_str) = model {
2874 quote! {
2875 use #crate_path::agent::impls::GeminiAgent;
2876 let agent = GeminiAgent::new().with_model_str(#model_str);
2877 }
2878 } else {
2879 quote! {
2880 use #crate_path::agent::impls::GeminiAgent;
2881 let agent = GeminiAgent::new();
2882 }
2883 }
2884 }
2885 "claude" => {
2886 if let Some(model_str) = model {
2887 quote! {
2888 use #crate_path::agent::impls::ClaudeCodeAgent;
2889 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
2890 }
2891 } else {
2892 quote! {
2893 use #crate_path::agent::impls::ClaudeCodeAgent;
2894 let agent = ClaudeCodeAgent::new();
2895 }
2896 }
2897 }
2898 _ => {
2899 if let Some(model_str) = model {
2901 quote! {
2902 use #crate_path::agent::impls::ClaudeCodeAgent;
2903 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
2904 }
2905 } else {
2906 quote! {
2907 use #crate_path::agent::impls::ClaudeCodeAgent;
2908 let agent = ClaudeCodeAgent::new();
2909 }
2910 }
2911 }
2912 };
2913
2914 let expanded = quote! {
2915 #[async_trait::async_trait]
2916 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
2917 type Output = #output_type;
2918
2919 fn expertise(&self) -> &str {
2920 #expertise
2921 }
2922
2923 async fn execute(&self, intent: String) -> Result<Self::Output, #crate_path::agent::AgentError> {
2924 #agent_init
2926
2927 let response = agent.execute(intent).await?;
2929
2930 let json_str = #crate_path::extract_json(&response)
2932 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
2933
2934 serde_json::from_str(&json_str)
2936 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
2937 }
2938
2939 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
2940 #agent_init
2942 agent.is_available().await
2943 }
2944 }
2945 };
2946
2947 TokenStream::from(expanded)
2948}