1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144
145 for (i, field) in fields.iter().enumerate() {
147 let field_name = field.ident.as_ref().unwrap();
148 let field_name_str = field_name.to_string();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if field_name_str == "__type" {
155 continue;
156 }
157
158 if attrs.skip {
160 continue;
161 }
162
163 let field_docs = extract_doc_comments(&field.attrs);
165
166 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
168
169 let remaining_fields = fields
171 .iter()
172 .skip(i + 1)
173 .filter(|f| {
174 let attrs = parse_field_prompt_attrs(&f.attrs);
175 !attrs.skip
176 })
177 .count();
178 let comma = if remaining_fields > 0 { "," } else { "" };
179
180 if is_vec {
181 let comment = if !field_docs.is_empty() {
183 format!(", // {}", field_docs)
184 } else {
185 String::new()
186 };
187
188 field_schema_parts.push(quote! {
189 {
190 let inner_schema = <#inner_type as #crate_path::prompt::ToPrompt>::prompt_schema();
191 if inner_schema.is_empty() {
192 format!(" \"{}\": \"{}[]\"{}{}", #field_name_str, stringify!(#inner_type).to_lowercase(), #comment, #comma)
194 } else {
195 let inner_lines: Vec<&str> = inner_schema.lines()
197 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
198 .take_while(|line| line.trim() != "}")
199 .collect();
200 let inner_content = inner_lines.join("\n");
201 format!(" \"{}\": [\n {{\n{}\n }}\n ]{}{}",
202 #field_name_str,
203 inner_content.lines()
204 .map(|line| format!(" {}", line))
205 .collect::<Vec<_>>()
206 .join("\n"),
207 #comment,
208 #comma
209 )
210 }
211 }
212 });
213 } else {
214 let field_type = &field.ty;
216 let is_primitive = is_primitive_type(field_type);
217
218 if !is_primitive {
219 let comment = if !field_docs.is_empty() {
221 format!(", // {}", field_docs)
222 } else {
223 String::new()
224 };
225
226 field_schema_parts.push(quote! {
227 {
228 let nested_schema = <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema();
229 if nested_schema.is_empty() {
230 let type_str = stringify!(#field_type).to_lowercase();
232 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
233 } else {
234 let nested_lines: Vec<&str> = nested_schema.lines()
236 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
237 .take_while(|line| line.trim() != "}")
238 .collect();
239
240 if nested_lines.is_empty() {
241 let type_str = stringify!(#field_type).to_lowercase();
243 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
244 } else {
245 let indented_content = nested_lines.iter()
247 .map(|line| format!(" {}", line))
248 .collect::<Vec<_>>()
249 .join("\n");
250 format!(" \"{}\": {{\n{}\n }}{}{}", #field_name_str, indented_content, #comment, #comma)
251 }
252 }
253 }
254 });
255 } else {
256 let type_str = format_type_for_schema(&field.ty);
258 let comment = if !field_docs.is_empty() {
259 format!(", // {}", field_docs)
260 } else {
261 String::new()
262 };
263
264 field_schema_parts.push(quote! {
265 format!(" \"{}\": \"{}\"{}{}", #field_name_str, #type_str, #comment, #comma)
266 });
267 }
268 }
269 }
270
271 let header = if !struct_docs.is_empty() {
273 format!("### Schema for `{}`\n{}", struct_name, struct_docs)
274 } else {
275 format!("### Schema for `{}`", struct_name)
276 };
277
278 quote! {
279 {
280 let mut lines = vec![#header.to_string(), "{".to_string()];
281 #(lines.push(#field_schema_parts);)*
282 lines.push("}".to_string());
283 vec![#crate_path::prompt::PromptPart::Text(lines.join("\n"))]
284 }
285 }
286}
287
288fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
290 if let syn::Type::Path(type_path) = ty
291 && let Some(last_segment) = type_path.path.segments.last()
292 && last_segment.ident == "Vec"
293 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
294 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
295 {
296 return (true, Some(inner_type));
297 }
298 (false, None)
299}
300
301fn is_primitive_type(ty: &syn::Type) -> bool {
303 if let syn::Type::Path(type_path) = ty
304 && let Some(last_segment) = type_path.path.segments.last()
305 {
306 let type_name = last_segment.ident.to_string();
307 matches!(
308 type_name.as_str(),
309 "String"
310 | "str"
311 | "i8"
312 | "i16"
313 | "i32"
314 | "i64"
315 | "i128"
316 | "isize"
317 | "u8"
318 | "u16"
319 | "u32"
320 | "u64"
321 | "u128"
322 | "usize"
323 | "f32"
324 | "f64"
325 | "bool"
326 | "Vec"
327 | "Option"
328 | "HashMap"
329 | "BTreeMap"
330 | "HashSet"
331 | "BTreeSet"
332 )
333 } else {
334 true
336 }
337}
338
339fn format_type_for_schema(ty: &syn::Type) -> String {
341 match ty {
343 syn::Type::Path(type_path) => {
344 let path = &type_path.path;
345 if let Some(last_segment) = path.segments.last() {
346 let type_name = last_segment.ident.to_string();
347
348 if type_name == "Option"
350 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
351 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
352 {
353 return format!("{} | null", format_type_for_schema(inner_type));
354 }
355
356 match type_name.as_str() {
358 "String" | "str" => "string".to_string(),
359 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
360 | "u64" | "u128" | "usize" => "number".to_string(),
361 "f32" | "f64" => "number".to_string(),
362 "bool" => "boolean".to_string(),
363 "Vec" => {
364 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
365 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
366 {
367 return format!("{}[]", format_type_for_schema(inner_type));
368 }
369 "array".to_string()
370 }
371 _ => type_name.to_lowercase(),
372 }
373 } else {
374 "unknown".to_string()
375 }
376 }
377 _ => "unknown".to_string(),
378 }
379}
380
/// Variant-level `#[prompt(...)]` setting, as returned by `parse_prompt_attribute`.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the enum's `prompt_schema()` listing, and
    /// `to_prompt()` renders only the variant name.
    Skip,
    /// `#[prompt("...")]`: description used instead of the variant's doc comment.
    Description(String),
    /// No recognized `#[prompt]` attribute; the variant's doc comments (if any) are used.
    None,
}
387
388fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
390 for attr in attrs {
391 if attr.path().is_ident("prompt") {
392 if let Ok(meta_list) = attr.meta.require_list() {
394 let tokens = &meta_list.tokens;
395 let tokens_str = tokens.to_string();
396 if tokens_str == "skip" {
397 return PromptAttribute::Skip;
398 }
399 }
400
401 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
403 return PromptAttribute::Description(lit_str.value());
404 }
405 }
406 }
407 PromptAttribute::None
408}
409
/// Field-level options parsed from `#[prompt(...)]` by `parse_field_prompt_attrs`.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: the field is left out of the default `key: value` rendering, the mode-based
    /// schema and the mode-based example.
    skip: bool,
    /// `rename = "..."`: label used for the field in the default `key: value` rendering.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&self.field` to produce the value
    /// text in the default rendering.
    format_with: Option<String>,
    /// `image`: the field's own `to_prompt_parts()` are appended to the prompt parts instead of
    /// rendering it as a text line.
    image: bool,
    /// `example = "..."`: string used for this field in the mode-based `example_only` output.
    example: Option<String>,
}
419
420fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
422 let mut result = FieldPromptAttrs::default();
423
424 for attr in attrs {
425 if attr.path().is_ident("prompt") {
426 if let Ok(meta_list) = attr.meta.require_list() {
428 if let Ok(metas) =
430 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
431 {
432 for meta in metas {
433 match meta {
434 Meta::Path(path) if path.is_ident("skip") => {
435 result.skip = true;
436 }
437 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
438 if let syn::Expr::Lit(syn::ExprLit {
439 lit: syn::Lit::Str(lit_str),
440 ..
441 }) = nv.value
442 {
443 result.rename = Some(lit_str.value());
444 }
445 }
446 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
447 if let syn::Expr::Lit(syn::ExprLit {
448 lit: syn::Lit::Str(lit_str),
449 ..
450 }) = nv.value
451 {
452 result.format_with = Some(lit_str.value());
453 }
454 }
455 Meta::Path(path) if path.is_ident("image") => {
456 result.image = true;
457 }
458 Meta::NameValue(nv) if nv.path.is_ident("example") => {
459 if let syn::Expr::Lit(syn::ExprLit {
460 lit: syn::Lit::Str(lit_str),
461 ..
462 }) = nv.value
463 {
464 result.example = Some(lit_str.value());
465 }
466 }
467 _ => {}
468 }
469 }
470 } else if meta_list.tokens.to_string() == "skip" {
471 result.skip = true;
473 } else if meta_list.tokens.to_string() == "image" {
474 result.image = true;
476 }
477 }
478 }
479 }
480
481 result
482}
483
484#[proc_macro_derive(ToPrompt, attributes(prompt))]
527pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
528 let input = parse_macro_input!(input as DeriveInput);
529
530 let found_crate =
531 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
532 let crate_path = match found_crate {
533 FoundCrate::Itself => {
534 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
536 quote!(::#ident)
537 }
538 FoundCrate::Name(name) => {
539 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
540 quote!(::#ident)
541 }
542 };
543
544 match &input.data {
546 Data::Enum(data_enum) => {
547 let enum_name = &input.ident;
549 let enum_docs = extract_doc_comments(&input.attrs);
550
551 let mut prompt_lines = Vec::new();
552
553 if !enum_docs.is_empty() {
555 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
556 } else {
557 prompt_lines.push(format!("{}:", enum_name));
558 }
559 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
561
562 for variant in &data_enum.variants {
564 let variant_name = &variant.ident;
565
566 match parse_prompt_attribute(&variant.attrs) {
568 PromptAttribute::Skip => {
569 continue;
571 }
572 PromptAttribute::Description(desc) => {
573 prompt_lines.push(format!("- {}: {}", variant_name, desc));
575 }
576 PromptAttribute::None => {
577 let variant_docs = extract_doc_comments(&variant.attrs);
579 if !variant_docs.is_empty() {
580 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
581 } else {
582 prompt_lines.push(format!("- {}", variant_name));
583 }
584 }
585 }
586 }
587
588 let prompt_string = prompt_lines.join("\n");
589 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
590
591 let mut match_arms = Vec::new();
593 for variant in &data_enum.variants {
594 let variant_name = &variant.ident;
595
596 match parse_prompt_attribute(&variant.attrs) {
598 PromptAttribute::Skip => {
599 match_arms.push(quote! {
601 Self::#variant_name => stringify!(#variant_name).to_string()
602 });
603 }
604 PromptAttribute::Description(desc) => {
605 match_arms.push(quote! {
607 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
608 });
609 }
610 PromptAttribute::None => {
611 let variant_docs = extract_doc_comments(&variant.attrs);
613 if !variant_docs.is_empty() {
614 match_arms.push(quote! {
615 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
616 });
617 } else {
618 match_arms.push(quote! {
619 Self::#variant_name => stringify!(#variant_name).to_string()
620 });
621 }
622 }
623 }
624 }
625
626 let to_prompt_impl = if match_arms.is_empty() {
627 quote! {
629 fn to_prompt(&self) -> String {
630 match *self {}
631 }
632 }
633 } else {
634 quote! {
635 fn to_prompt(&self) -> String {
636 match self {
637 #(#match_arms),*
638 }
639 }
640 }
641 };
642
643 let expanded = quote! {
644 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
645 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
646 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
647 }
648
649 #to_prompt_impl
650
651 fn prompt_schema() -> String {
652 #prompt_string.to_string()
653 }
654 }
655 };
656
657 TokenStream::from(expanded)
658 }
659 Data::Struct(data_struct) => {
660 let mut template_attr = None;
662 let mut template_file_attr = None;
663 let mut mode_attr = None;
664 let mut validate_attr = false;
665 let mut type_marker_attr = false;
666
667 for attr in &input.attrs {
668 if attr.path().is_ident("prompt") {
669 if let Ok(metas) =
671 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
672 {
673 for meta in metas {
674 match meta {
675 Meta::NameValue(nv) if nv.path.is_ident("template") => {
676 if let syn::Expr::Lit(expr_lit) = nv.value
677 && let syn::Lit::Str(lit_str) = expr_lit.lit
678 {
679 template_attr = Some(lit_str.value());
680 }
681 }
682 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
683 if let syn::Expr::Lit(expr_lit) = nv.value
684 && let syn::Lit::Str(lit_str) = expr_lit.lit
685 {
686 template_file_attr = Some(lit_str.value());
687 }
688 }
689 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
690 if let syn::Expr::Lit(expr_lit) = nv.value
691 && let syn::Lit::Str(lit_str) = expr_lit.lit
692 {
693 mode_attr = Some(lit_str.value());
694 }
695 }
696 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
697 if let syn::Expr::Lit(expr_lit) = nv.value
698 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
699 {
700 validate_attr = lit_bool.value();
701 }
702 }
703 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
704 if let syn::Expr::Lit(expr_lit) = nv.value
705 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
706 {
707 type_marker_attr = lit_bool.value();
708 }
709 }
710 Meta::Path(path) if path.is_ident("type_marker") => {
711 type_marker_attr = true;
713 }
714 _ => {}
715 }
716 }
717 }
718 }
719 }
720
721 if template_attr.is_some() && template_file_attr.is_some() {
723 return syn::Error::new(
724 input.ident.span(),
725 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
726 ).to_compile_error().into();
727 }
728
729 let template_str = if let Some(file_path) = template_file_attr {
731 let mut full_path = None;
735
736 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
738 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
740
741 if !is_trybuild {
742 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
744 if candidate.exists() {
745 full_path = Some(candidate);
746 }
747 } else {
748 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
754 let workspace_root = &manifest_dir[..target_pos];
755 let original_macros_dir = std::path::Path::new(workspace_root)
757 .join("crates")
758 .join("llm-toolkit-macros");
759
760 let candidate = original_macros_dir.join(&file_path);
761 if candidate.exists() {
762 full_path = Some(candidate);
763 }
764 }
765 }
766 }
767
768 if full_path.is_none() {
770 let candidate = std::path::Path::new(&file_path).to_path_buf();
771 if candidate.exists() {
772 full_path = Some(candidate);
773 }
774 }
775
776 if full_path.is_none()
779 && let Ok(current_dir) = std::env::current_dir()
780 {
781 let mut search_dir = current_dir.as_path();
782 for _ in 0..10 {
784 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
786 if macros_dir.exists() {
787 let candidate = macros_dir.join(&file_path);
788 if candidate.exists() {
789 full_path = Some(candidate);
790 break;
791 }
792 }
793 let candidate = search_dir.join(&file_path);
795 if candidate.exists() {
796 full_path = Some(candidate);
797 break;
798 }
799 if let Some(parent) = search_dir.parent() {
800 search_dir = parent;
801 } else {
802 break;
803 }
804 }
805 }
806
807 if full_path.is_none() {
809 let mut error_msg = format!(
811 "Template file '{}' not found at compile time.\n\nSearched in:",
812 file_path
813 );
814
815 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
816 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
817 error_msg.push_str(&format!("\n - {}", candidate.display()));
818 }
819
820 if let Ok(current_dir) = std::env::current_dir() {
821 let candidate = current_dir.join(&file_path);
822 error_msg.push_str(&format!("\n - {}", candidate.display()));
823 }
824
825 error_msg.push_str("\n\nPlease ensure:");
826 error_msg.push_str("\n 1. The template file exists");
827 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
828 error_msg.push_str("\n 3. There are no typos in the path");
829
830 return syn::Error::new(input.ident.span(), error_msg)
831 .to_compile_error()
832 .into();
833 }
834
835 let final_path = full_path.unwrap();
836
837 match std::fs::read_to_string(&final_path) {
839 Ok(content) => Some(content),
840 Err(e) => {
841 return syn::Error::new(
842 input.ident.span(),
843 format!(
844 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
845 file_path,
846 e,
847 final_path.display()
848 ),
849 )
850 .to_compile_error()
851 .into();
852 }
853 }
854 } else {
855 template_attr
856 };
857
858 if validate_attr && let Some(template) = &template_str {
860 let mut env = minijinja::Environment::new();
862 if let Err(e) = env.add_template("validation", template) {
863 let warning_msg =
865 format!("Template validation warning: Invalid Jinja syntax - {}", e);
866 let warning_ident = syn::Ident::new(
867 "TEMPLATE_VALIDATION_WARNING",
868 proc_macro2::Span::call_site(),
869 );
870 let _warning_tokens = quote! {
871 #[deprecated(note = #warning_msg)]
872 const #warning_ident: () = ();
873 let _ = #warning_ident;
874 };
875 eprintln!("cargo:warning={}", warning_msg);
877 }
878
879 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
881 &fields.named
882 } else {
883 panic!("Template validation is only supported for structs with named fields.");
884 };
885
886 let field_names: std::collections::HashSet<String> = fields
887 .iter()
888 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
889 .collect();
890
891 let placeholders = parse_template_placeholders_with_mode(template);
893
894 for (placeholder_name, _mode) in &placeholders {
895 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
896 let warning_msg = format!(
897 "Template validation warning: Variable '{}' used in template but not found in struct fields",
898 placeholder_name
899 );
900 eprintln!("cargo:warning={}", warning_msg);
901 }
902 }
903 }
904
905 let name = input.ident;
906 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
907
908 let struct_docs = extract_doc_comments(&input.attrs);
910
911 let is_mode_based =
913 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
914
915 let expanded = if is_mode_based || mode_attr.is_some() {
916 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
918 &fields.named
919 } else {
920 panic!(
921 "Mode-based prompt generation is only supported for structs with named fields."
922 );
923 };
924
925 let struct_name_str = name.to_string();
926
927 let has_default = input.attrs.iter().any(|attr| {
929 if attr.path().is_ident("derive")
930 && let Ok(meta_list) = attr.meta.require_list()
931 {
932 let tokens_str = meta_list.tokens.to_string();
933 tokens_str.contains("Default")
934 } else {
935 false
936 }
937 });
938
939 let schema_parts = generate_schema_only_parts(
941 &struct_name_str,
942 &struct_docs,
943 fields,
944 &crate_path,
945 type_marker_attr,
946 );
947
948 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
950
951 quote! {
952 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
953 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
954 match mode {
955 "schema_only" => #schema_parts,
956 "example_only" => #example_parts,
957 "full" | _ => {
958 let mut parts = Vec::new();
960
961 let schema_parts = #schema_parts;
963 parts.extend(schema_parts);
964
965 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
967 parts.push(#crate_path::prompt::PromptPart::Text(
968 format!("Here is an example of a valid `{}` object:", #struct_name_str)
969 ));
970
971 let example_parts = #example_parts;
973 parts.extend(example_parts);
974
975 parts
976 }
977 }
978 }
979
980 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
981 self.to_prompt_parts_with_mode("full")
982 }
983
984 fn to_prompt(&self) -> String {
985 self.to_prompt_parts()
986 .into_iter()
987 .filter_map(|part| match part {
988 #crate_path::prompt::PromptPart::Text(text) => Some(text),
989 _ => None,
990 })
991 .collect::<Vec<_>>()
992 .join("\n")
993 }
994
995 fn prompt_schema() -> String {
996 use std::sync::OnceLock;
997 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
998
999 SCHEMA_CACHE.get_or_init(|| {
1000 let schema_parts = #schema_parts;
1001 schema_parts
1002 .into_iter()
1003 .filter_map(|part| match part {
1004 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1005 _ => None,
1006 })
1007 .collect::<Vec<_>>()
1008 .join("\n")
1009 }).clone()
1010 }
1011 }
1012 }
1013 } else if let Some(template) = template_str {
1014 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1017 &fields.named
1018 } else {
1019 panic!(
1020 "Template prompt generation is only supported for structs with named fields."
1021 );
1022 };
1023
1024 let placeholders = parse_template_placeholders_with_mode(&template);
1026 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1028 mode.is_some()
1029 && fields
1030 .iter()
1031 .any(|f| f.ident.as_ref().unwrap() == field_name)
1032 });
1033
1034 let mut image_field_parts = Vec::new();
1035 for f in fields.iter() {
1036 let field_name = f.ident.as_ref().unwrap();
1037 let attrs = parse_field_prompt_attrs(&f.attrs);
1038
1039 if attrs.image {
1040 image_field_parts.push(quote! {
1042 parts.extend(self.#field_name.to_prompt_parts());
1043 });
1044 }
1045 }
1046
1047 if has_mode_syntax {
1049 let mut context_fields = Vec::new();
1051 let mut modified_template = template.clone();
1052
1053 for (field_name, mode_opt) in &placeholders {
1055 if let Some(mode) = mode_opt {
1056 let unique_key = format!("{}__{}", field_name, mode);
1058
1059 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1061 let replacement = format!("{{{{ {} }}}}", unique_key);
1062 modified_template = modified_template.replace(&pattern, &replacement);
1063
1064 let field_ident =
1066 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1067
1068 context_fields.push(quote! {
1070 context.insert(
1071 #unique_key.to_string(),
1072 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1073 );
1074 });
1075 }
1076 }
1077
1078 for field in fields.iter() {
1080 let field_name = field.ident.as_ref().unwrap();
1081 let field_name_str = field_name.to_string();
1082
1083 let has_mode_entry = placeholders
1085 .iter()
1086 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1087
1088 if !has_mode_entry {
1089 let is_primitive = match &field.ty {
1092 syn::Type::Path(type_path) => {
1093 if let Some(segment) = type_path.path.segments.last() {
1094 let type_name = segment.ident.to_string();
1095 matches!(
1096 type_name.as_str(),
1097 "String"
1098 | "str"
1099 | "i8"
1100 | "i16"
1101 | "i32"
1102 | "i64"
1103 | "i128"
1104 | "isize"
1105 | "u8"
1106 | "u16"
1107 | "u32"
1108 | "u64"
1109 | "u128"
1110 | "usize"
1111 | "f32"
1112 | "f64"
1113 | "bool"
1114 | "char"
1115 )
1116 } else {
1117 false
1118 }
1119 }
1120 _ => false,
1121 };
1122
1123 if is_primitive {
1124 context_fields.push(quote! {
1125 context.insert(
1126 #field_name_str.to_string(),
1127 minijinja::Value::from_serialize(&self.#field_name)
1128 );
1129 });
1130 } else {
1131 context_fields.push(quote! {
1133 context.insert(
1134 #field_name_str.to_string(),
1135 minijinja::Value::from(self.#field_name.to_prompt())
1136 );
1137 });
1138 }
1139 }
1140 }
1141
1142 quote! {
1143 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1144 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1145 let mut parts = Vec::new();
1146
1147 #(#image_field_parts)*
1149
1150 let text = {
1152 let mut env = minijinja::Environment::new();
1153 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1154 panic!("Failed to parse template: {}", e)
1155 });
1156
1157 let tmpl = env.get_template("prompt").unwrap();
1158
1159 let mut context = std::collections::HashMap::new();
1160 #(#context_fields)*
1161
1162 tmpl.render(context).unwrap_or_else(|e| {
1163 format!("Failed to render prompt: {}", e)
1164 })
1165 };
1166
1167 if !text.is_empty() {
1168 parts.push(#crate_path::prompt::PromptPart::Text(text));
1169 }
1170
1171 parts
1172 }
1173
1174 fn to_prompt(&self) -> String {
1175 let mut env = minijinja::Environment::new();
1177 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1178 panic!("Failed to parse template: {}", e)
1179 });
1180
1181 let tmpl = env.get_template("prompt").unwrap();
1182
1183 let mut context = std::collections::HashMap::new();
1184 #(#context_fields)*
1185
1186 tmpl.render(context).unwrap_or_else(|e| {
1187 format!("Failed to render prompt: {}", e)
1188 })
1189 }
1190
1191 fn prompt_schema() -> String {
1192 String::new() }
1194 }
1195 }
1196 } else {
1197 quote! {
1199 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1200 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1201 let mut parts = Vec::new();
1202
1203 #(#image_field_parts)*
1205
1206 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1208 format!("Failed to render prompt: {}", e)
1209 });
1210 if !text.is_empty() {
1211 parts.push(#crate_path::prompt::PromptPart::Text(text));
1212 }
1213
1214 parts
1215 }
1216
1217 fn to_prompt(&self) -> String {
1218 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1219 format!("Failed to render prompt: {}", e)
1220 })
1221 }
1222
1223 fn prompt_schema() -> String {
1224 String::new() }
1226 }
1227 }
1228 }
1229 } else {
1230 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1233 &fields.named
1234 } else {
1235 panic!(
1236 "Default prompt generation is only supported for structs with named fields."
1237 );
1238 };
1239
1240 let mut text_field_parts = Vec::new();
1242 let mut image_field_parts = Vec::new();
1243
1244 for f in fields.iter() {
1245 let field_name = f.ident.as_ref().unwrap();
1246 let attrs = parse_field_prompt_attrs(&f.attrs);
1247
1248 if attrs.skip {
1250 continue;
1251 }
1252
1253 if attrs.image {
1254 image_field_parts.push(quote! {
1256 parts.extend(self.#field_name.to_prompt_parts());
1257 });
1258 } else {
1259 let key = if let Some(rename) = attrs.rename {
1265 rename
1266 } else {
1267 let doc_comment = extract_doc_comments(&f.attrs);
1268 if !doc_comment.is_empty() {
1269 doc_comment
1270 } else {
1271 field_name.to_string()
1272 }
1273 };
1274
1275 let value_expr = if let Some(format_with) = attrs.format_with {
1277 let func_path: syn::Path =
1279 syn::parse_str(&format_with).unwrap_or_else(|_| {
1280 panic!("Invalid function path: {}", format_with)
1281 });
1282 quote! { #func_path(&self.#field_name) }
1283 } else {
1284 quote! { self.#field_name.to_prompt() }
1285 };
1286
1287 text_field_parts.push(quote! {
1288 text_parts.push(format!("{}: {}", #key, #value_expr));
1289 });
1290 }
1291 }
1292
1293 quote! {
1295 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1296 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1297 let mut parts = Vec::new();
1298
1299 #(#image_field_parts)*
1301
1302 let mut text_parts = Vec::new();
1304 #(#text_field_parts)*
1305
1306 if !text_parts.is_empty() {
1307 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1308 }
1309
1310 parts
1311 }
1312
1313 fn to_prompt(&self) -> String {
1314 let mut text_parts = Vec::new();
1315 #(#text_field_parts)*
1316 text_parts.join("\n")
1317 }
1318
1319 fn prompt_schema() -> String {
1320 String::new() }
1322 }
1323 }
1324 };
1325
1326 TokenStream::from(expanded)
1327 }
1328 Data::Union(_) => {
1329 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1330 }
1331 }
1332}
1333
/// One `#[prompt_for(...)]` target declared on a struct, as collected by
/// `parse_struct_prompt_for_attrs`.
/// NOTE(review): how the fields are populated and consumed lies outside this view; confirm.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name — presumably from the attribute's `name = "..."`; TODO confirm.
    name: String,
    /// Optional template specific to this target — TODO confirm source attribute.
    template: Option<String>,
    /// Per-field overrides for this target, presumably keyed by field name; verify against
    /// the builder.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}
1341
/// Field-level `#[prompt_for(...)]` options for one target, built by `parse_prompt_for_attrs`.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Set by a `skip` entry, or by a bare `#[prompt_for(skip)]`, which applies to target "*".
    skip: bool,
    /// `rename = "..."` for this target.
    rename: Option<String>,
    /// `format_with = "..."` function path (as a string) for this target.
    format_with: Option<String>,
    /// `image` flag for this target.
    image: bool,
    /// Set to `true` whenever the attribute names a target via `name = "..."`.
    include_only: bool,
}
1351
1352fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1354 let mut configs = Vec::new();
1355
1356 for attr in attrs {
1357 if attr.path().is_ident("prompt_for")
1358 && let Ok(meta_list) = attr.meta.require_list()
1359 {
1360 if meta_list.tokens.to_string() == "skip" {
1362 let config = FieldTargetConfig {
1364 skip: true,
1365 ..Default::default()
1366 };
1367 configs.push(("*".to_string(), config));
1368 } else if let Ok(metas) =
1369 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1370 {
1371 let mut target_name = None;
1372 let mut config = FieldTargetConfig::default();
1373
1374 for meta in metas {
1375 match meta {
1376 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1377 if let syn::Expr::Lit(syn::ExprLit {
1378 lit: syn::Lit::Str(lit_str),
1379 ..
1380 }) = nv.value
1381 {
1382 target_name = Some(lit_str.value());
1383 }
1384 }
1385 Meta::Path(path) if path.is_ident("skip") => {
1386 config.skip = true;
1387 }
1388 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1389 if let syn::Expr::Lit(syn::ExprLit {
1390 lit: syn::Lit::Str(lit_str),
1391 ..
1392 }) = nv.value
1393 {
1394 config.rename = Some(lit_str.value());
1395 }
1396 }
1397 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1398 if let syn::Expr::Lit(syn::ExprLit {
1399 lit: syn::Lit::Str(lit_str),
1400 ..
1401 }) = nv.value
1402 {
1403 config.format_with = Some(lit_str.value());
1404 }
1405 }
1406 Meta::Path(path) if path.is_ident("image") => {
1407 config.image = true;
1408 }
1409 _ => {}
1410 }
1411 }
1412
1413 if let Some(name) = target_name {
1414 config.include_only = true;
1415 configs.push((name, config));
1416 }
1417 }
1418 }
1419 }
1420
1421 configs
1422}
1423
1424fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1426 let mut targets = Vec::new();
1427
1428 for attr in attrs {
1429 if attr.path().is_ident("prompt_for")
1430 && let Ok(meta_list) = attr.meta.require_list()
1431 && let Ok(metas) =
1432 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1433 {
1434 let mut target_name = None;
1435 let mut template = None;
1436
1437 for meta in metas {
1438 match meta {
1439 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1440 if let syn::Expr::Lit(syn::ExprLit {
1441 lit: syn::Lit::Str(lit_str),
1442 ..
1443 }) = nv.value
1444 {
1445 target_name = Some(lit_str.value());
1446 }
1447 }
1448 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1449 if let syn::Expr::Lit(syn::ExprLit {
1450 lit: syn::Lit::Str(lit_str),
1451 ..
1452 }) = nv.value
1453 {
1454 template = Some(lit_str.value());
1455 }
1456 }
1457 _ => {}
1458 }
1459 }
1460
1461 if let Some(name) = target_name {
1462 targets.push(TargetInfo {
1463 name,
1464 template,
1465 field_configs: std::collections::HashMap::new(),
1466 });
1467 }
1468 }
1469 }
1470
1471 targets
1472}
1473
1474#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1475pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1476 let input = parse_macro_input!(input as DeriveInput);
1477
1478 let found_crate =
1479 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1480 let crate_path = match found_crate {
1481 FoundCrate::Itself => {
1482 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1484 quote!(::#ident)
1485 }
1486 FoundCrate::Name(name) => {
1487 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1488 quote!(::#ident)
1489 }
1490 };
1491
1492 let data_struct = match &input.data {
1494 Data::Struct(data) => data,
1495 _ => {
1496 return syn::Error::new(
1497 input.ident.span(),
1498 "`#[derive(ToPromptSet)]` is only supported for structs",
1499 )
1500 .to_compile_error()
1501 .into();
1502 }
1503 };
1504
1505 let fields = match &data_struct.fields {
1506 syn::Fields::Named(fields) => &fields.named,
1507 _ => {
1508 return syn::Error::new(
1509 input.ident.span(),
1510 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1511 )
1512 .to_compile_error()
1513 .into();
1514 }
1515 };
1516
1517 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1519
1520 for field in fields.iter() {
1522 let field_name = field.ident.as_ref().unwrap().to_string();
1523 let field_configs = parse_prompt_for_attrs(&field.attrs);
1524
1525 for (target_name, config) in field_configs {
1526 if target_name == "*" {
1527 for target in &mut targets {
1529 target
1530 .field_configs
1531 .entry(field_name.clone())
1532 .or_insert_with(FieldTargetConfig::default)
1533 .skip = config.skip;
1534 }
1535 } else {
1536 let target_exists = targets.iter().any(|t| t.name == target_name);
1538 if !target_exists {
1539 targets.push(TargetInfo {
1541 name: target_name.clone(),
1542 template: None,
1543 field_configs: std::collections::HashMap::new(),
1544 });
1545 }
1546
1547 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1548
1549 target.field_configs.insert(field_name.clone(), config);
1550 }
1551 }
1552 }
1553
1554 let mut match_arms = Vec::new();
1556
1557 for target in &targets {
1558 let target_name = &target.name;
1559
1560 if let Some(template_str) = &target.template {
1561 let mut image_parts = Vec::new();
1563
1564 for field in fields.iter() {
1565 let field_name = field.ident.as_ref().unwrap();
1566 let field_name_str = field_name.to_string();
1567
1568 if let Some(config) = target.field_configs.get(&field_name_str)
1569 && config.image
1570 {
1571 image_parts.push(quote! {
1572 parts.extend(self.#field_name.to_prompt_parts());
1573 });
1574 }
1575 }
1576
1577 match_arms.push(quote! {
1578 #target_name => {
1579 let mut parts = Vec::new();
1580
1581 #(#image_parts)*
1582
1583 let text = #crate_path::prompt::render_prompt(#template_str, self)
1584 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1585 target: #target_name.to_string(),
1586 source: e,
1587 })?;
1588
1589 if !text.is_empty() {
1590 parts.push(#crate_path::prompt::PromptPart::Text(text));
1591 }
1592
1593 Ok(parts)
1594 }
1595 });
1596 } else {
1597 let mut text_field_parts = Vec::new();
1599 let mut image_field_parts = Vec::new();
1600
1601 for field in fields.iter() {
1602 let field_name = field.ident.as_ref().unwrap();
1603 let field_name_str = field_name.to_string();
1604
1605 let config = target.field_configs.get(&field_name_str);
1607
1608 if let Some(cfg) = config
1610 && cfg.skip
1611 {
1612 continue;
1613 }
1614
1615 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1619 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1620 .iter()
1621 .any(|(name, _)| name != "*");
1622
1623 if has_any_target_specific_config && !is_explicitly_for_this_target {
1624 continue;
1625 }
1626
1627 if let Some(cfg) = config {
1628 if cfg.image {
1629 image_field_parts.push(quote! {
1630 parts.extend(self.#field_name.to_prompt_parts());
1631 });
1632 } else {
1633 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1634
1635 let value_expr = if let Some(format_with) = &cfg.format_with {
1636 match syn::parse_str::<syn::Path>(format_with) {
1638 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1639 Err(_) => {
1640 let error_msg = format!(
1642 "Invalid function path in format_with: '{}'",
1643 format_with
1644 );
1645 quote! {
1646 compile_error!(#error_msg);
1647 String::new()
1648 }
1649 }
1650 }
1651 } else {
1652 quote! { self.#field_name.to_prompt() }
1653 };
1654
1655 text_field_parts.push(quote! {
1656 text_parts.push(format!("{}: {}", #key, #value_expr));
1657 });
1658 }
1659 } else {
1660 text_field_parts.push(quote! {
1662 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1663 });
1664 }
1665 }
1666
1667 match_arms.push(quote! {
1668 #target_name => {
1669 let mut parts = Vec::new();
1670
1671 #(#image_field_parts)*
1672
1673 let mut text_parts = Vec::new();
1674 #(#text_field_parts)*
1675
1676 if !text_parts.is_empty() {
1677 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1678 }
1679
1680 Ok(parts)
1681 }
1682 });
1683 }
1684 }
1685
1686 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1688
1689 match_arms.push(quote! {
1691 _ => {
1692 let available = vec![#(#target_names.to_string()),*];
1693 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1694 target: target.to_string(),
1695 available,
1696 })
1697 }
1698 });
1699
1700 let struct_name = &input.ident;
1701 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1702
1703 let expanded = quote! {
1704 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1705 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1706 match target {
1707 #(#match_arms)*
1708 }
1709 }
1710 }
1711 };
1712
1713 TokenStream::from(expanded)
1714}
1715
/// Comma-separated list of types accepted by the `examples_section!` macro.
struct TypeList {
    /// The parsed types in source order; a trailing comma is accepted.
    types: Punctuated<syn::Type, Token![,]>,
}
1720
1721impl Parse for TypeList {
1722 fn parse(input: ParseStream) -> syn::Result<Self> {
1723 Ok(TypeList {
1724 types: Punctuated::parse_terminated(input)?,
1725 })
1726 }
1727}
1728
1729#[proc_macro]
1753pub fn examples_section(input: TokenStream) -> TokenStream {
1754 let input = parse_macro_input!(input as TypeList);
1755
1756 let found_crate =
1757 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1758 let _crate_path = match found_crate {
1759 FoundCrate::Itself => quote!(crate),
1760 FoundCrate::Name(name) => {
1761 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1762 quote!(::#ident)
1763 }
1764 };
1765
1766 let mut type_sections = Vec::new();
1768
1769 for ty in input.types.iter() {
1770 let type_name_str = quote!(#ty).to_string();
1772
1773 type_sections.push(quote! {
1775 {
1776 let type_name = #type_name_str;
1777 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1778 format!("---\n#### `{}`\n{}", type_name, json_example)
1779 }
1780 });
1781 }
1782
1783 let expanded = quote! {
1785 {
1786 let mut sections = Vec::new();
1787 sections.push("---".to_string());
1788 sections.push("### Examples".to_string());
1789 sections.push("".to_string());
1790 sections.push("Here are examples of the data structures you should use.".to_string());
1791 sections.push("".to_string());
1792
1793 #(sections.push(#type_sections);)*
1794
1795 sections.push("---".to_string());
1796
1797 sections.join("\n")
1798 }
1799 };
1800
1801 TokenStream::from(expanded)
1802}
1803
1804fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1806 for attr in attrs {
1807 if attr.path().is_ident("prompt_for")
1808 && let Ok(meta_list) = attr.meta.require_list()
1809 && let Ok(metas) =
1810 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1811 {
1812 let mut target_type = None;
1813 let mut template = None;
1814
1815 for meta in metas {
1816 match meta {
1817 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1818 if let syn::Expr::Lit(syn::ExprLit {
1819 lit: syn::Lit::Str(lit_str),
1820 ..
1821 }) = nv.value
1822 {
1823 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1825 }
1826 }
1827 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1828 if let syn::Expr::Lit(syn::ExprLit {
1829 lit: syn::Lit::Str(lit_str),
1830 ..
1831 }) = nv.value
1832 {
1833 template = Some(lit_str.value());
1834 }
1835 }
1836 _ => {}
1837 }
1838 }
1839
1840 if let (Some(target), Some(tmpl)) = (target_type, template) {
1841 return (target, tmpl);
1842 }
1843 }
1844 }
1845
1846 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1847}
1848
1849#[proc_macro_attribute]
1883pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1884 let input = parse_macro_input!(item as DeriveInput);
1885
1886 let found_crate =
1887 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1888 let crate_path = match found_crate {
1889 FoundCrate::Itself => {
1890 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1892 quote!(::#ident)
1893 }
1894 FoundCrate::Name(name) => {
1895 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1896 quote!(::#ident)
1897 }
1898 };
1899
1900 let enum_data = match &input.data {
1902 Data::Enum(data) => data,
1903 _ => {
1904 return syn::Error::new(
1905 input.ident.span(),
1906 "`#[define_intent]` can only be applied to enums",
1907 )
1908 .to_compile_error()
1909 .into();
1910 }
1911 };
1912
1913 let mut prompt_template = None;
1915 let mut extractor_tag = None;
1916 let mut mode = None;
1917
1918 for attr in &input.attrs {
1919 if attr.path().is_ident("intent")
1920 && let Ok(metas) =
1921 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1922 {
1923 for meta in metas {
1924 match meta {
1925 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1926 if let syn::Expr::Lit(syn::ExprLit {
1927 lit: syn::Lit::Str(lit_str),
1928 ..
1929 }) = nv.value
1930 {
1931 prompt_template = Some(lit_str.value());
1932 }
1933 }
1934 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1935 if let syn::Expr::Lit(syn::ExprLit {
1936 lit: syn::Lit::Str(lit_str),
1937 ..
1938 }) = nv.value
1939 {
1940 extractor_tag = Some(lit_str.value());
1941 }
1942 }
1943 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1944 if let syn::Expr::Lit(syn::ExprLit {
1945 lit: syn::Lit::Str(lit_str),
1946 ..
1947 }) = nv.value
1948 {
1949 mode = Some(lit_str.value());
1950 }
1951 }
1952 _ => {}
1953 }
1954 }
1955 }
1956 }
1957
1958 let mode = mode.unwrap_or_else(|| "single".to_string());
1960
1961 if mode != "single" && mode != "multi_tag" {
1963 return syn::Error::new(
1964 input.ident.span(),
1965 "`mode` must be either \"single\" or \"multi_tag\"",
1966 )
1967 .to_compile_error()
1968 .into();
1969 }
1970
1971 let prompt_template = match prompt_template {
1973 Some(p) => p,
1974 None => {
1975 return syn::Error::new(
1976 input.ident.span(),
1977 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1978 )
1979 .to_compile_error()
1980 .into();
1981 }
1982 };
1983
1984 if mode == "multi_tag" {
1986 let enum_name = &input.ident;
1987 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1988 return generate_multi_tag_output(
1989 &input,
1990 enum_name,
1991 enum_data,
1992 prompt_template,
1993 actions_doc,
1994 );
1995 }
1996
1997 let extractor_tag = match extractor_tag {
1999 Some(t) => t,
2000 None => {
2001 return syn::Error::new(
2002 input.ident.span(),
2003 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2004 )
2005 .to_compile_error()
2006 .into();
2007 }
2008 };
2009
2010 let enum_name = &input.ident;
2012 let enum_docs = extract_doc_comments(&input.attrs);
2013
2014 let mut intents_doc_lines = Vec::new();
2015
2016 if !enum_docs.is_empty() {
2018 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2019 } else {
2020 intents_doc_lines.push(format!("{}:", enum_name));
2021 }
2022 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2024
2025 for variant in &enum_data.variants {
2027 let variant_name = &variant.ident;
2028 let variant_docs = extract_doc_comments(&variant.attrs);
2029
2030 if !variant_docs.is_empty() {
2031 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2032 } else {
2033 intents_doc_lines.push(format!("- {}", variant_name));
2034 }
2035 }
2036
2037 let intents_doc_str = intents_doc_lines.join("\n");
2038
2039 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2041 let user_variables: Vec<String> = placeholders
2042 .iter()
2043 .filter_map(|(name, _)| {
2044 if name != "intents_doc" {
2045 Some(name.clone())
2046 } else {
2047 None
2048 }
2049 })
2050 .collect();
2051
2052 let enum_name_str = enum_name.to_string();
2054 let snake_case_name = to_snake_case(&enum_name_str);
2055 let function_name = syn::Ident::new(
2056 &format!("build_{}_prompt", snake_case_name),
2057 proc_macro2::Span::call_site(),
2058 );
2059
2060 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2062 .iter()
2063 .map(|var| {
2064 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2065 quote! { #ident: &str }
2066 })
2067 .collect();
2068
2069 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2071 .iter()
2072 .map(|var| {
2073 let var_str = var.clone();
2074 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2075 quote! {
2076 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2077 }
2078 })
2079 .collect();
2080
2081 let converted_template = prompt_template.clone();
2083
2084 let extractor_name = syn::Ident::new(
2086 &format!("{}Extractor", enum_name),
2087 proc_macro2::Span::call_site(),
2088 );
2089
2090 let filtered_attrs: Vec<_> = input
2092 .attrs
2093 .iter()
2094 .filter(|attr| !attr.path().is_ident("intent"))
2095 .collect();
2096
2097 let vis = &input.vis;
2099 let generics = &input.generics;
2100 let variants = &enum_data.variants;
2101 let enum_output = quote! {
2102 #(#filtered_attrs)*
2103 #vis enum #enum_name #generics {
2104 #variants
2105 }
2106 };
2107
2108 let expanded = quote! {
2110 #enum_output
2112
2113 pub fn #function_name(#(#function_params),*) -> String {
2115 let mut env = minijinja::Environment::new();
2116 env.add_template("prompt", #converted_template)
2117 .expect("Failed to parse intent prompt template");
2118
2119 let tmpl = env.get_template("prompt").unwrap();
2120
2121 let mut __template_context = std::collections::HashMap::new();
2122
2123 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2125
2126 #(#context_insertions)*
2128
2129 tmpl.render(&__template_context)
2130 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2131 }
2132
2133 pub struct #extractor_name;
2135
2136 impl #extractor_name {
2137 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2138 }
2139
2140 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2141 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2142 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2144 }
2145 }
2146 };
2147
2148 TokenStream::from(expanded)
2149}
2150
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when the preceding
/// character was not uppercase, so runs of capitals stay together
/// (`HTTPRequest` becomes `httprequest`).
fn to_snake_case(s: &str) -> String {
    let mut snake = String::new();
    let mut last_was_upper = false;

    for (index, c) in s.chars().enumerate() {
        let is_upper = c.is_uppercase();
        if !is_upper {
            snake.push(c);
        } else {
            if index != 0 && !last_was_upper {
                snake.push('_');
            }
            // `to_lowercase` may yield several chars; only the first one is kept.
            snake.extend(c.to_lowercase().take(1));
        }
        last_was_upper = is_upper;
    }

    snake
}
2171
/// Variant-level options parsed from `#[action(...)]` in `multi_tag` intent enums.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// XML tag name from `#[action(tag = "...")]`; `None` when the variant has no tag.
    tag: Option<String>,
}
2177
2178fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2179 let mut result = ActionAttrs::default();
2180
2181 for attr in attrs {
2182 if attr.path().is_ident("action")
2183 && let Ok(meta_list) = attr.meta.require_list()
2184 && let Ok(metas) =
2185 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2186 {
2187 for meta in metas {
2188 if let Meta::NameValue(nv) = meta
2189 && nv.path.is_ident("tag")
2190 && let syn::Expr::Lit(syn::ExprLit {
2191 lit: syn::Lit::Str(lit_str),
2192 ..
2193 }) = nv.value
2194 {
2195 result.tag = Some(lit_str.value());
2196 }
2197 }
2198 }
2199 }
2200
2201 result
2202}
2203
/// Field-level options parsed from `#[action(attribute)]` / `#[action(inner_text)]`.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// The field is read from the XML attribute whose key equals the field name.
    is_attribute: bool,
    /// The field receives the element's text content.
    is_inner_text: bool,
}
2210
2211fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2212 let mut result = FieldActionAttrs::default();
2213
2214 for attr in attrs {
2215 if attr.path().is_ident("action")
2216 && let Ok(meta_list) = attr.meta.require_list()
2217 {
2218 let tokens_str = meta_list.tokens.to_string();
2219 if tokens_str == "attribute" {
2220 result.is_attribute = true;
2221 } else if tokens_str == "inner_text" {
2222 result.is_inner_text = true;
2223 }
2224 }
2225 }
2226
2227 result
2228}
2229
2230fn generate_multi_tag_actions_doc(
2232 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2233) -> String {
2234 let mut doc_lines = Vec::new();
2235
2236 for variant in variants {
2237 let action_attrs = parse_action_attrs(&variant.attrs);
2238
2239 if let Some(tag) = action_attrs.tag {
2240 let variant_docs = extract_doc_comments(&variant.attrs);
2241
2242 match &variant.fields {
2243 syn::Fields::Unit => {
2244 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2246 }
2247 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2248 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2250 }
2251 syn::Fields::Named(fields) => {
2252 let mut attrs_str = Vec::new();
2254 let mut has_inner_text = false;
2255
2256 for field in &fields.named {
2257 let field_name = field.ident.as_ref().unwrap();
2258 let field_attrs = parse_field_action_attrs(&field.attrs);
2259
2260 if field_attrs.is_attribute {
2261 attrs_str.push(format!("{}=\"...\"", field_name));
2262 } else if field_attrs.is_inner_text {
2263 has_inner_text = true;
2264 }
2265 }
2266
2267 let attrs_part = if !attrs_str.is_empty() {
2268 format!(" {}", attrs_str.join(" "))
2269 } else {
2270 String::new()
2271 };
2272
2273 if has_inner_text {
2274 doc_lines.push(format!(
2275 "- `<{}{}>...</{}>`: {}",
2276 tag, attrs_part, tag, variant_docs
2277 ));
2278 } else if !attrs_str.is_empty() {
2279 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2280 } else {
2281 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2282 }
2283
2284 for field in &fields.named {
2286 let field_name = field.ident.as_ref().unwrap();
2287 let field_attrs = parse_field_action_attrs(&field.attrs);
2288 let field_docs = extract_doc_comments(&field.attrs);
2289
2290 if field_attrs.is_attribute {
2291 doc_lines
2292 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2293 } else if field_attrs.is_inner_text {
2294 doc_lines
2295 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2296 }
2297 }
2298 }
2299 _ => {
2300 }
2302 }
2303 }
2304 }
2305
2306 doc_lines.join("\n")
2307}
2308
2309fn generate_tags_regex(
2311 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2312) -> String {
2313 let mut tag_names = Vec::new();
2314
2315 for variant in variants {
2316 let action_attrs = parse_action_attrs(&variant.attrs);
2317 if let Some(tag) = action_attrs.tag {
2318 tag_names.push(tag);
2319 }
2320 }
2321
2322 if tag_names.is_empty() {
2323 return String::new();
2324 }
2325
2326 let tags_pattern = tag_names.join("|");
2327 format!(
2330 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2331 tags_pattern, tags_pattern, tags_pattern
2332 )
2333}
2334
2335fn generate_multi_tag_output(
2337 input: &DeriveInput,
2338 enum_name: &syn::Ident,
2339 enum_data: &syn::DataEnum,
2340 prompt_template: String,
2341 actions_doc: String,
2342) -> TokenStream {
2343 let found_crate =
2344 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2345 let crate_path = match found_crate {
2346 FoundCrate::Itself => {
2347 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2349 quote!(::#ident)
2350 }
2351 FoundCrate::Name(name) => {
2352 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2353 quote!(::#ident)
2354 }
2355 };
2356
2357 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2359 let user_variables: Vec<String> = placeholders
2360 .iter()
2361 .filter_map(|(name, _)| {
2362 if name != "actions_doc" {
2363 Some(name.clone())
2364 } else {
2365 None
2366 }
2367 })
2368 .collect();
2369
2370 let enum_name_str = enum_name.to_string();
2372 let snake_case_name = to_snake_case(&enum_name_str);
2373 let function_name = syn::Ident::new(
2374 &format!("build_{}_prompt", snake_case_name),
2375 proc_macro2::Span::call_site(),
2376 );
2377
2378 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2380 .iter()
2381 .map(|var| {
2382 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2383 quote! { #ident: &str }
2384 })
2385 .collect();
2386
2387 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2389 .iter()
2390 .map(|var| {
2391 let var_str = var.clone();
2392 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2393 quote! {
2394 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2395 }
2396 })
2397 .collect();
2398
2399 let extractor_name = syn::Ident::new(
2401 &format!("{}Extractor", enum_name),
2402 proc_macro2::Span::call_site(),
2403 );
2404
2405 let filtered_attrs: Vec<_> = input
2407 .attrs
2408 .iter()
2409 .filter(|attr| !attr.path().is_ident("intent"))
2410 .collect();
2411
2412 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2414 .variants
2415 .iter()
2416 .map(|variant| {
2417 let variant_name = &variant.ident;
2418 let variant_attrs: Vec<_> = variant
2419 .attrs
2420 .iter()
2421 .filter(|attr| !attr.path().is_ident("action"))
2422 .collect();
2423 let fields = &variant.fields;
2424
2425 let filtered_fields = match fields {
2427 syn::Fields::Named(named_fields) => {
2428 let filtered: Vec<_> = named_fields
2429 .named
2430 .iter()
2431 .map(|field| {
2432 let field_name = &field.ident;
2433 let field_type = &field.ty;
2434 let field_vis = &field.vis;
2435 let filtered_attrs: Vec<_> = field
2436 .attrs
2437 .iter()
2438 .filter(|attr| !attr.path().is_ident("action"))
2439 .collect();
2440 quote! {
2441 #(#filtered_attrs)*
2442 #field_vis #field_name: #field_type
2443 }
2444 })
2445 .collect();
2446 quote! { { #(#filtered,)* } }
2447 }
2448 syn::Fields::Unnamed(unnamed_fields) => {
2449 let types: Vec<_> = unnamed_fields
2450 .unnamed
2451 .iter()
2452 .map(|field| {
2453 let field_type = &field.ty;
2454 quote! { #field_type }
2455 })
2456 .collect();
2457 quote! { (#(#types),*) }
2458 }
2459 syn::Fields::Unit => quote! {},
2460 };
2461
2462 quote! {
2463 #(#variant_attrs)*
2464 #variant_name #filtered_fields
2465 }
2466 })
2467 .collect();
2468
2469 let vis = &input.vis;
2470 let generics = &input.generics;
2471
2472 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2474
2475 let tags_regex = generate_tags_regex(&enum_data.variants);
2477
2478 let expanded = quote! {
2479 #(#filtered_attrs)*
2481 #vis enum #enum_name #generics {
2482 #(#filtered_variants),*
2483 }
2484
2485 pub fn #function_name(#(#function_params),*) -> String {
2487 let mut env = minijinja::Environment::new();
2488 env.add_template("prompt", #prompt_template)
2489 .expect("Failed to parse intent prompt template");
2490
2491 let tmpl = env.get_template("prompt").unwrap();
2492
2493 let mut __template_context = std::collections::HashMap::new();
2494
2495 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2497
2498 #(#context_insertions)*
2500
2501 tmpl.render(&__template_context)
2502 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2503 }
2504
2505 pub struct #extractor_name;
2507
2508 impl #extractor_name {
2509 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2510 use ::quick_xml::events::Event;
2511 use ::quick_xml::Reader;
2512
2513 let mut actions = Vec::new();
2514 let mut reader = Reader::from_str(text);
2515 reader.config_mut().trim_text(true);
2516
2517 let mut buf = Vec::new();
2518
2519 loop {
2520 match reader.read_event_into(&mut buf) {
2521 Ok(Event::Start(e)) => {
2522 let owned_e = e.into_owned();
2523 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2524 let is_empty = false;
2525
2526 #parsing_arms
2527 }
2528 Ok(Event::Empty(e)) => {
2529 let owned_e = e.into_owned();
2530 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2531 let is_empty = true;
2532
2533 #parsing_arms
2534 }
2535 Ok(Event::Eof) => break,
2536 Err(_) => {
2537 break;
2539 }
2540 _ => {}
2541 }
2542 buf.clear();
2543 }
2544
2545 actions.into_iter().next()
2546 }
2547
2548 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2549 use ::quick_xml::events::Event;
2550 use ::quick_xml::Reader;
2551
2552 let mut actions = Vec::new();
2553 let mut reader = Reader::from_str(text);
2554 reader.config_mut().trim_text(true);
2555
2556 let mut buf = Vec::new();
2557
2558 loop {
2559 match reader.read_event_into(&mut buf) {
2560 Ok(Event::Start(e)) => {
2561 let owned_e = e.into_owned();
2562 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2563 let is_empty = false;
2564
2565 #parsing_arms
2566 }
2567 Ok(Event::Empty(e)) => {
2568 let owned_e = e.into_owned();
2569 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2570 let is_empty = true;
2571
2572 #parsing_arms
2573 }
2574 Ok(Event::Eof) => break,
2575 Err(_) => {
2576 break;
2578 }
2579 _ => {}
2580 }
2581 buf.clear();
2582 }
2583
2584 Ok(actions)
2585 }
2586
2587 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2588 where
2589 F: FnMut(#enum_name) -> String,
2590 {
2591 use ::regex::Regex;
2592
2593 let regex_pattern = #tags_regex;
2594 if regex_pattern.is_empty() {
2595 return text.to_string();
2596 }
2597
2598 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2599 panic!("Failed to compile regex for action tags: {}", e);
2600 });
2601
2602 re.replace_all(text, |caps: &::regex::Captures| {
2603 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2604
2605 if let Some(action) = self.parse_single_action(matched) {
2607 transformer(action)
2608 } else {
2609 matched.to_string()
2611 }
2612 }).to_string()
2613 }
2614
2615 pub fn strip_actions(&self, text: &str) -> String {
2616 self.transform_actions(text, |_| String::new())
2617 }
2618 }
2619 };
2620
2621 TokenStream::from(expanded)
2622}
2623
2624fn generate_parsing_arms(
2626 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2627 enum_name: &syn::Ident,
2628) -> proc_macro2::TokenStream {
2629 let mut arms = Vec::new();
2630
2631 for variant in variants {
2632 let variant_name = &variant.ident;
2633 let action_attrs = parse_action_attrs(&variant.attrs);
2634
2635 if let Some(tag) = action_attrs.tag {
2636 match &variant.fields {
2637 syn::Fields::Unit => {
2638 arms.push(quote! {
2640 if &tag_name == #tag {
2641 actions.push(#enum_name::#variant_name);
2642 }
2643 });
2644 }
2645 syn::Fields::Unnamed(_fields) => {
2646 arms.push(quote! {
2648 if &tag_name == #tag && !is_empty {
2649 match reader.read_text(owned_e.name()) {
2651 Ok(text) => {
2652 actions.push(#enum_name::#variant_name(text.to_string()));
2653 }
2654 Err(_) => {
2655 actions.push(#enum_name::#variant_name(String::new()));
2657 }
2658 }
2659 }
2660 });
2661 }
2662 syn::Fields::Named(fields) => {
2663 let mut field_names = Vec::new();
2665 let mut has_inner_text_field = None;
2666
2667 for field in &fields.named {
2668 let field_name = field.ident.as_ref().unwrap();
2669 let field_attrs = parse_field_action_attrs(&field.attrs);
2670
2671 if field_attrs.is_attribute {
2672 field_names.push(field_name.clone());
2673 } else if field_attrs.is_inner_text {
2674 has_inner_text_field = Some(field_name.clone());
2675 }
2676 }
2677
2678 if let Some(inner_text_field) = has_inner_text_field {
2679 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2682 quote! {
2683 let mut #field_name = String::new();
2684 for attr in owned_e.attributes() {
2685 if let Ok(attr) = attr {
2686 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2687 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2688 break;
2689 }
2690 }
2691 }
2692 }
2693 }).collect();
2694
2695 arms.push(quote! {
2696 if &tag_name == #tag {
2697 #(#attr_extractions)*
2698
2699 if is_empty {
2701 let #inner_text_field = String::new();
2702 actions.push(#enum_name::#variant_name {
2703 #(#field_names,)*
2704 #inner_text_field,
2705 });
2706 } else {
2707 match reader.read_text(owned_e.name()) {
2709 Ok(text) => {
2710 let #inner_text_field = text.to_string();
2711 actions.push(#enum_name::#variant_name {
2712 #(#field_names,)*
2713 #inner_text_field,
2714 });
2715 }
2716 Err(_) => {
2717 let #inner_text_field = String::new();
2719 actions.push(#enum_name::#variant_name {
2720 #(#field_names,)*
2721 #inner_text_field,
2722 });
2723 }
2724 }
2725 }
2726 }
2727 });
2728 } else {
2729 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2731 quote! {
2732 let mut #field_name = String::new();
2733 for attr in owned_e.attributes() {
2734 if let Ok(attr) = attr {
2735 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2736 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2737 break;
2738 }
2739 }
2740 }
2741 }
2742 }).collect();
2743
2744 arms.push(quote! {
2745 if &tag_name == #tag {
2746 #(#attr_extractions)*
2747 actions.push(#enum_name::#variant_name {
2748 #(#field_names),*
2749 });
2750 }
2751 });
2752 }
2753 }
2754 }
2755 }
2756 }
2757
2758 quote! {
2759 #(#arms)*
2760 }
2761}
2762
2763#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2765pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2766 let input = parse_macro_input!(input as DeriveInput);
2767
2768 let found_crate =
2769 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2770 let crate_path = match found_crate {
2771 FoundCrate::Itself => {
2772 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2774 quote!(::#ident)
2775 }
2776 FoundCrate::Name(name) => {
2777 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2778 quote!(::#ident)
2779 }
2780 };
2781
2782 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2784
2785 let struct_name = &input.ident;
2786 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2787
2788 let placeholders = parse_template_placeholders_with_mode(&template);
2790
2791 let mut converted_template = template.clone();
2793 let mut context_fields = Vec::new();
2794
2795 let fields = match &input.data {
2797 Data::Struct(data_struct) => match &data_struct.fields {
2798 syn::Fields::Named(fields) => &fields.named,
2799 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2800 },
2801 _ => panic!("ToPromptFor is only supported for structs"),
2802 };
2803
2804 let has_mode_support = input.attrs.iter().any(|attr| {
2806 if attr.path().is_ident("prompt")
2807 && let Ok(metas) =
2808 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2809 {
2810 for meta in metas {
2811 if let Meta::NameValue(nv) = meta
2812 && nv.path.is_ident("mode")
2813 {
2814 return true;
2815 }
2816 }
2817 }
2818 false
2819 });
2820
2821 for (placeholder_name, mode_opt) in &placeholders {
2823 if placeholder_name == "self" {
2824 if let Some(specific_mode) = mode_opt {
2825 let unique_key = format!("self__{}", specific_mode);
2827
2828 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2830 let replacement = format!("{{{{ {} }}}}", unique_key);
2831 converted_template = converted_template.replace(&pattern, &replacement);
2832
2833 context_fields.push(quote! {
2835 context.insert(
2836 #unique_key.to_string(),
2837 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2838 );
2839 });
2840 } else {
2841 if has_mode_support {
2844 context_fields.push(quote! {
2846 context.insert(
2847 "self".to_string(),
2848 minijinja::Value::from(self.to_prompt_with_mode(mode))
2849 );
2850 });
2851 } else {
2852 context_fields.push(quote! {
2854 context.insert(
2855 "self".to_string(),
2856 minijinja::Value::from(self.to_prompt())
2857 );
2858 });
2859 }
2860 }
2861 } else {
2862 let field_exists = fields.iter().any(|f| {
2865 f.ident
2866 .as_ref()
2867 .is_some_and(|ident| ident == placeholder_name)
2868 });
2869
2870 if field_exists {
2871 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2872
2873 context_fields.push(quote! {
2877 context.insert(
2878 #placeholder_name.to_string(),
2879 minijinja::Value::from_serialize(&self.#field_ident)
2880 );
2881 });
2882 }
2883 }
2885 }
2886
2887 let expanded = quote! {
2888 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2889 where
2890 #target_type: serde::Serialize,
2891 {
2892 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2893 let mut env = minijinja::Environment::new();
2895 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2896 panic!("Failed to parse template: {}", e)
2897 });
2898
2899 let tmpl = env.get_template("prompt").unwrap();
2900
2901 let mut context = std::collections::HashMap::new();
2903 context.insert(
2905 "self".to_string(),
2906 minijinja::Value::from_serialize(self)
2907 );
2908 context.insert(
2910 "target".to_string(),
2911 minijinja::Value::from_serialize(target)
2912 );
2913 #(#context_fields)*
2914
2915 tmpl.render(context).unwrap_or_else(|e| {
2917 format!("Failed to render prompt: {}", e)
2918 })
2919 }
2920 }
2921 };
2922
2923 TokenStream::from(expanded)
2924}
2925
2926struct AgentAttrs {
2932 expertise: Option<String>,
2933 output: Option<syn::Type>,
2934 backend: Option<String>,
2935 model: Option<String>,
2936 inner: Option<String>,
2937 default_inner: Option<String>,
2938 max_retries: Option<u32>,
2939 profile: Option<String>,
2940}
2941
2942impl Parse for AgentAttrs {
2943 fn parse(input: ParseStream) -> syn::Result<Self> {
2944 let mut expertise = None;
2945 let mut output = None;
2946 let mut backend = None;
2947 let mut model = None;
2948 let mut inner = None;
2949 let mut default_inner = None;
2950 let mut max_retries = None;
2951 let mut profile = None;
2952
2953 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2954
2955 for meta in pairs {
2956 match meta {
2957 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2958 if let syn::Expr::Lit(syn::ExprLit {
2959 lit: syn::Lit::Str(lit_str),
2960 ..
2961 }) = &nv.value
2962 {
2963 expertise = Some(lit_str.value());
2964 }
2965 }
2966 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2967 if let syn::Expr::Lit(syn::ExprLit {
2968 lit: syn::Lit::Str(lit_str),
2969 ..
2970 }) = &nv.value
2971 {
2972 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2973 output = Some(ty);
2974 }
2975 }
2976 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2977 if let syn::Expr::Lit(syn::ExprLit {
2978 lit: syn::Lit::Str(lit_str),
2979 ..
2980 }) = &nv.value
2981 {
2982 backend = Some(lit_str.value());
2983 }
2984 }
2985 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2986 if let syn::Expr::Lit(syn::ExprLit {
2987 lit: syn::Lit::Str(lit_str),
2988 ..
2989 }) = &nv.value
2990 {
2991 model = Some(lit_str.value());
2992 }
2993 }
2994 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
2995 if let syn::Expr::Lit(syn::ExprLit {
2996 lit: syn::Lit::Str(lit_str),
2997 ..
2998 }) = &nv.value
2999 {
3000 inner = Some(lit_str.value());
3001 }
3002 }
3003 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3004 if let syn::Expr::Lit(syn::ExprLit {
3005 lit: syn::Lit::Str(lit_str),
3006 ..
3007 }) = &nv.value
3008 {
3009 default_inner = Some(lit_str.value());
3010 }
3011 }
3012 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3013 if let syn::Expr::Lit(syn::ExprLit {
3014 lit: syn::Lit::Int(lit_int),
3015 ..
3016 }) = &nv.value
3017 {
3018 max_retries = Some(lit_int.base10_parse()?);
3019 }
3020 }
3021 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3022 if let syn::Expr::Lit(syn::ExprLit {
3023 lit: syn::Lit::Str(lit_str),
3024 ..
3025 }) = &nv.value
3026 {
3027 profile = Some(lit_str.value());
3028 }
3029 }
3030 _ => {}
3031 }
3032 }
3033
3034 Ok(AgentAttrs {
3035 expertise,
3036 output,
3037 backend,
3038 model,
3039 inner,
3040 default_inner,
3041 max_retries,
3042 profile,
3043 })
3044 }
3045}
3046
3047fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3049 for attr in attrs {
3050 if attr.path().is_ident("agent") {
3051 return attr.parse_args::<AgentAttrs>();
3052 }
3053 }
3054
3055 Ok(AgentAttrs {
3056 expertise: None,
3057 output: None,
3058 backend: None,
3059 model: None,
3060 inner: None,
3061 default_inner: None,
3062 max_retries: None,
3063 profile: None,
3064 })
3065}
3066
3067fn generate_backend_constructors(
3069 struct_name: &syn::Ident,
3070 backend: &str,
3071 _model: Option<&str>,
3072 _profile: Option<&str>,
3073 crate_path: &proc_macro2::TokenStream,
3074) -> proc_macro2::TokenStream {
3075 match backend {
3076 "claude" => {
3077 quote! {
3078 impl #struct_name {
3079 pub fn with_claude() -> Self {
3081 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3082 }
3083
3084 pub fn with_claude_model(model: &str) -> Self {
3086 Self::new(
3087 #crate_path::agent::impls::ClaudeCodeAgent::new()
3088 .with_model_str(model)
3089 )
3090 }
3091 }
3092 }
3093 }
3094 "gemini" => {
3095 quote! {
3096 impl #struct_name {
3097 pub fn with_gemini() -> Self {
3099 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3100 }
3101
3102 pub fn with_gemini_model(model: &str) -> Self {
3104 Self::new(
3105 #crate_path::agent::impls::GeminiAgent::new()
3106 .with_model_str(model)
3107 )
3108 }
3109 }
3110 }
3111 }
3112 _ => quote! {},
3113 }
3114}
3115
3116fn generate_default_impl(
3118 struct_name: &syn::Ident,
3119 backend: &str,
3120 model: Option<&str>,
3121 profile: Option<&str>,
3122 crate_path: &proc_macro2::TokenStream,
3123) -> proc_macro2::TokenStream {
3124 let profile_expr = if let Some(profile_str) = profile {
3126 match profile_str.to_lowercase().as_str() {
3127 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3128 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3129 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3130 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3132 } else {
3133 quote! { #crate_path::agent::ExecutionProfile::default() }
3134 };
3135
3136 let agent_init = match backend {
3137 "gemini" => {
3138 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3139
3140 if let Some(model_str) = model {
3141 builder = quote! { #builder.with_model_str(#model_str) };
3142 }
3143
3144 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3145 builder
3146 }
3147 _ => {
3148 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3150
3151 if let Some(model_str) = model {
3152 builder = quote! { #builder.with_model_str(#model_str) };
3153 }
3154
3155 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3156 builder
3157 }
3158 };
3159
3160 quote! {
3161 impl Default for #struct_name {
3162 fn default() -> Self {
3163 Self::new(#agent_init)
3164 }
3165 }
3166 }
3167}
3168
3169#[proc_macro_derive(Agent, attributes(agent))]
3178pub fn derive_agent(input: TokenStream) -> TokenStream {
3179 let input = parse_macro_input!(input as DeriveInput);
3180 let struct_name = &input.ident;
3181
3182 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3184 Ok(attrs) => attrs,
3185 Err(e) => return e.to_compile_error().into(),
3186 };
3187
3188 let expertise = agent_attrs
3189 .expertise
3190 .unwrap_or_else(|| String::from("general AI assistant"));
3191 let output_type = agent_attrs
3192 .output
3193 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3194 let backend = agent_attrs
3195 .backend
3196 .unwrap_or_else(|| String::from("claude"));
3197 let model = agent_attrs.model;
3198 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3203 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3204 let crate_path = match found_crate {
3205 FoundCrate::Itself => {
3206 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3208 quote!(::#ident)
3209 }
3210 FoundCrate::Name(name) => {
3211 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3212 quote!(::#ident)
3213 }
3214 };
3215
3216 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3217
3218 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3220 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3221
3222 let enhanced_expertise = if is_string_output {
3224 quote! { #expertise }
3226 } else {
3227 let type_name = quote!(#output_type).to_string();
3229 quote! {
3230 {
3231 use std::sync::OnceLock;
3232 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3233
3234 EXPERTISE_CACHE.get_or_init(|| {
3235 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3237
3238 if schema.is_empty() {
3239 format!(
3241 concat!(
3242 #expertise,
3243 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3244 "Do not include any text outside the JSON object."
3245 ),
3246 #type_name
3247 )
3248 } else {
3249 format!(
3251 concat!(
3252 #expertise,
3253 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3254 ),
3255 schema
3256 )
3257 }
3258 }).as_str()
3259 }
3260 }
3261 };
3262
3263 let agent_init = match backend.as_str() {
3265 "gemini" => {
3266 if let Some(model_str) = model {
3267 quote! {
3268 use #crate_path::agent::impls::GeminiAgent;
3269 let agent = GeminiAgent::new().with_model_str(#model_str);
3270 }
3271 } else {
3272 quote! {
3273 use #crate_path::agent::impls::GeminiAgent;
3274 let agent = GeminiAgent::new();
3275 }
3276 }
3277 }
3278 "claude" => {
3279 if let Some(model_str) = model {
3280 quote! {
3281 use #crate_path::agent::impls::ClaudeCodeAgent;
3282 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3283 }
3284 } else {
3285 quote! {
3286 use #crate_path::agent::impls::ClaudeCodeAgent;
3287 let agent = ClaudeCodeAgent::new();
3288 }
3289 }
3290 }
3291 _ => {
3292 if let Some(model_str) = model {
3294 quote! {
3295 use #crate_path::agent::impls::ClaudeCodeAgent;
3296 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3297 }
3298 } else {
3299 quote! {
3300 use #crate_path::agent::impls::ClaudeCodeAgent;
3301 let agent = ClaudeCodeAgent::new();
3302 }
3303 }
3304 }
3305 };
3306
3307 let expanded = quote! {
3308 #[async_trait::async_trait]
3309 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3310 type Output = #output_type;
3311
3312 fn expertise(&self) -> &str {
3313 #enhanced_expertise
3314 }
3315
3316 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3317 #agent_init
3319
3320 let max_retries: u32 = #max_retries;
3322 let mut attempts = 0u32;
3323
3324 loop {
3325 attempts += 1;
3326
3327 let result = async {
3329 let response = agent.execute(intent.clone()).await?;
3330
3331 let json_str = #crate_path::extract_json(&response)
3333 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3334
3335 serde_json::from_str::<Self::Output>(&json_str)
3337 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3338 }.await;
3339
3340 match result {
3341 Ok(output) => return Ok(output),
3342 Err(e) if e.is_retryable() && attempts < max_retries => {
3343 log::warn!(
3345 "Agent execution failed (attempt {}/{}): {}. Retrying...",
3346 attempts,
3347 max_retries,
3348 e
3349 );
3350
3351 let delay_ms = 100 * attempts as u64;
3353 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
3354
3355 continue;
3357 }
3358 Err(e) => {
3359 if attempts > 1 {
3360 log::error!(
3361 "Agent execution failed after {} attempts: {}",
3362 attempts,
3363 e
3364 );
3365 }
3366 return Err(e);
3367 }
3368 }
3369 }
3370 }
3371
3372 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3373 #agent_init
3375 agent.is_available().await
3376 }
3377 }
3378 };
3379
3380 TokenStream::from(expanded)
3381}
3382
3383#[proc_macro_attribute]
3398pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3399 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3401 Ok(attrs) => attrs,
3402 Err(e) => return e.to_compile_error().into(),
3403 };
3404
3405 let input = parse_macro_input!(item as DeriveInput);
3407 let struct_name = &input.ident;
3408 let vis = &input.vis;
3409
3410 let expertise = agent_attrs
3411 .expertise
3412 .unwrap_or_else(|| String::from("general AI assistant"));
3413 let output_type = agent_attrs
3414 .output
3415 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3416 let backend = agent_attrs
3417 .backend
3418 .unwrap_or_else(|| String::from("claude"));
3419 let model = agent_attrs.model;
3420 let profile = agent_attrs.profile;
3421
3422 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3424 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3425
3426 let found_crate =
3428 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3429 let crate_path = match found_crate {
3430 FoundCrate::Itself => {
3431 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3432 quote!(::#ident)
3433 }
3434 FoundCrate::Name(name) => {
3435 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3436 quote!(::#ident)
3437 }
3438 };
3439
3440 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3442 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3443
3444 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3446 let type_path: syn::Type =
3448 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3449 quote! { #type_path }
3450 } else {
3451 match backend.as_str() {
3453 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3454 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3455 }
3456 };
3457
3458 let struct_def = quote! {
3460 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3461 inner: #inner_generic_ident,
3462 }
3463 };
3464
3465 let constructors = quote! {
3467 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3468 pub fn new(inner: #inner_generic_ident) -> Self {
3470 Self { inner }
3471 }
3472 }
3473 };
3474
3475 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3477 let default_impl = quote! {
3479 impl Default for #struct_name {
3480 fn default() -> Self {
3481 Self {
3482 inner: <#default_agent_type as Default>::default(),
3483 }
3484 }
3485 }
3486 };
3487 (quote! {}, default_impl)
3488 } else {
3489 let backend_constructors = generate_backend_constructors(
3491 struct_name,
3492 &backend,
3493 model.as_deref(),
3494 profile.as_deref(),
3495 &crate_path,
3496 );
3497 let default_impl = generate_default_impl(
3498 struct_name,
3499 &backend,
3500 model.as_deref(),
3501 profile.as_deref(),
3502 &crate_path,
3503 );
3504 (backend_constructors, default_impl)
3505 };
3506
3507 let enhanced_expertise = if is_string_output {
3509 quote! { #expertise }
3511 } else {
3512 let type_name = quote!(#output_type).to_string();
3514 quote! {
3515 {
3516 use std::sync::OnceLock;
3517 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3518
3519 EXPERTISE_CACHE.get_or_init(|| {
3520 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3522
3523 if schema.is_empty() {
3524 format!(
3526 concat!(
3527 #expertise,
3528 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3529 "Do not include any text outside the JSON object."
3530 ),
3531 #type_name
3532 )
3533 } else {
3534 format!(
3536 concat!(
3537 #expertise,
3538 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3539 ),
3540 schema
3541 )
3542 }
3543 }).as_str()
3544 }
3545 }
3546 };
3547
3548 let agent_impl = quote! {
3550 #[async_trait::async_trait]
3551 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3552 where
3553 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3554 {
3555 type Output = #output_type;
3556
3557 fn expertise(&self) -> &str {
3558 #enhanced_expertise
3559 }
3560
3561 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3562 let enhanced_payload = intent.prepend_text(self.expertise());
3564
3565 let response = self.inner.execute(enhanced_payload).await?;
3567
3568 let json_str = #crate_path::extract_json(&response)
3570 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3571
3572 serde_json::from_str(&json_str)
3574 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3575 }
3576
3577 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3578 self.inner.is_available().await
3579 }
3580 }
3581 };
3582
3583 let expanded = quote! {
3584 #struct_def
3585 #constructors
3586 #backend_constructors
3587 #default_impl
3588 #agent_impl
3589 };
3590
3591 TokenStream::from(expanded)
3592}
3593
3594#[proc_macro_derive(TypeMarker)]
3616pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3617 let input = parse_macro_input!(input as DeriveInput);
3618 let struct_name = &input.ident;
3619 let type_name_str = struct_name.to_string();
3620
3621 let found_crate =
3623 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3624 let crate_path = match found_crate {
3625 FoundCrate::Itself => {
3626 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3627 quote!(::#ident)
3628 }
3629 FoundCrate::Name(name) => {
3630 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3631 quote!(::#ident)
3632 }
3633 };
3634
3635 let expanded = quote! {
3636 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3637 const TYPE_NAME: &'static str = #type_name_str;
3638 }
3639 };
3640
3641 TokenStream::from(expanded)
3642}