1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if attrs.skip {
76 continue;
77 }
78
79 if let Some(example) = attrs.example {
81 field_values.push(quote! {
83 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
84 });
85 } else if has_default {
86 field_values.push(quote! {
88 let default_value = serde_json::to_value(&default_instance.#field_name)
89 .unwrap_or(serde_json::Value::Null);
90 json_obj.insert(#field_name_str.to_string(), default_value);
91 });
92 } else {
93 field_values.push(quote! {
95 let value = serde_json::to_value(&self.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), value);
98 });
99 }
100 }
101
102 if has_default {
103 quote! {
104 {
105 let default_instance = Self::default();
106 let mut json_obj = serde_json::Map::new();
107 #(#field_values)*
108 let json_value = serde_json::Value::Object(json_obj);
109 let json_str = serde_json::to_string_pretty(&json_value)
110 .unwrap_or_else(|_| "{}".to_string());
111 vec![#crate_path::prompt::PromptPart::Text(json_str)]
112 }
113 }
114 } else {
115 quote! {
116 {
117 let mut json_obj = serde_json::Map::new();
118 #(#field_values)*
119 let json_value = serde_json::Value::Object(json_obj);
120 let json_str = serde_json::to_string_pretty(&json_value)
121 .unwrap_or_else(|_| "{}".to_string());
122 vec![#crate_path::prompt::PromptPart::Text(json_str)]
123 }
124 }
125 }
126}
127
128fn generate_schema_only_parts(
130 struct_name: &str,
131 struct_docs: &str,
132 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
133 crate_path: &proc_macro2::TokenStream,
134 has_type_marker: bool,
135) -> proc_macro2::TokenStream {
136 let mut field_schema_parts = vec![];
137
138 if has_type_marker {
140 field_schema_parts.push(quote! {
141 format!(" \"__type\": \"string\",")
142 });
143 }
144
145 for (i, field) in fields.iter().enumerate() {
147 let field_name = field.ident.as_ref().unwrap();
148 let field_name_str = field_name.to_string();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if attrs.skip {
153 continue;
154 }
155
156 let field_docs = extract_doc_comments(&field.attrs);
158
159 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
161
162 let remaining_fields = fields
164 .iter()
165 .skip(i + 1)
166 .filter(|f| {
167 let attrs = parse_field_prompt_attrs(&f.attrs);
168 !attrs.skip
169 })
170 .count();
171 let comma = if remaining_fields > 0 { "," } else { "" };
172
173 if is_vec {
174 let comment = if !field_docs.is_empty() {
176 format!(", // {}", field_docs)
177 } else {
178 String::new()
179 };
180
181 field_schema_parts.push(quote! {
182 {
183 let inner_schema = <#inner_type as #crate_path::prompt::ToPrompt>::prompt_schema();
184 if inner_schema.is_empty() {
185 format!(" \"{}\": \"{}[]\"{}{}", #field_name_str, stringify!(#inner_type).to_lowercase(), #comment, #comma)
187 } else {
188 let inner_lines: Vec<&str> = inner_schema.lines()
190 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
191 .take_while(|line| line.trim() != "}")
192 .collect();
193 let inner_content = inner_lines.join("\n");
194 format!(" \"{}\": [\n {{\n{}\n }}\n ]{}{}",
195 #field_name_str,
196 inner_content.lines()
197 .map(|line| format!(" {}", line))
198 .collect::<Vec<_>>()
199 .join("\n"),
200 #comment,
201 #comma
202 )
203 }
204 }
205 });
206 } else {
207 let field_type = &field.ty;
209 let is_primitive = is_primitive_type(field_type);
210
211 if !is_primitive {
212 let comment = if !field_docs.is_empty() {
214 format!(", // {}", field_docs)
215 } else {
216 String::new()
217 };
218
219 field_schema_parts.push(quote! {
220 {
221 let nested_schema = <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema();
222 if nested_schema.is_empty() {
223 let type_str = stringify!(#field_type).to_lowercase();
225 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
226 } else {
227 let nested_lines: Vec<&str> = nested_schema.lines()
229 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
230 .take_while(|line| line.trim() != "}")
231 .collect();
232
233 if nested_lines.is_empty() {
234 let type_str = stringify!(#field_type).to_lowercase();
236 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
237 } else {
238 let indented_content = nested_lines.iter()
240 .map(|line| format!(" {}", line))
241 .collect::<Vec<_>>()
242 .join("\n");
243 format!(" \"{}\": {{\n{}\n }}{}{}", #field_name_str, indented_content, #comment, #comma)
244 }
245 }
246 }
247 });
248 } else {
249 let type_str = format_type_for_schema(&field.ty);
251 let comment = if !field_docs.is_empty() {
252 format!(", // {}", field_docs)
253 } else {
254 String::new()
255 };
256
257 field_schema_parts.push(quote! {
258 format!(" \"{}\": \"{}\"{}{}", #field_name_str, #type_str, #comment, #comma)
259 });
260 }
261 }
262 }
263
264 let header = if !struct_docs.is_empty() {
266 format!("### Schema for `{}`\n{}", struct_name, struct_docs)
267 } else {
268 format!("### Schema for `{}`", struct_name)
269 };
270
271 quote! {
272 {
273 let mut lines = vec![#header.to_string(), "{".to_string()];
274 #(lines.push(#field_schema_parts);)*
275 lines.push("}".to_string());
276 vec![#crate_path::prompt::PromptPart::Text(lines.join("\n"))]
277 }
278 }
279}
280
281fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
283 if let syn::Type::Path(type_path) = ty
284 && let Some(last_segment) = type_path.path.segments.last()
285 && last_segment.ident == "Vec"
286 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
287 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
288 {
289 return (true, Some(inner_type));
290 }
291 (false, None)
292}
293
294fn is_primitive_type(ty: &syn::Type) -> bool {
296 if let syn::Type::Path(type_path) = ty
297 && let Some(last_segment) = type_path.path.segments.last()
298 {
299 let type_name = last_segment.ident.to_string();
300 matches!(
301 type_name.as_str(),
302 "String"
303 | "str"
304 | "i8"
305 | "i16"
306 | "i32"
307 | "i64"
308 | "i128"
309 | "isize"
310 | "u8"
311 | "u16"
312 | "u32"
313 | "u64"
314 | "u128"
315 | "usize"
316 | "f32"
317 | "f64"
318 | "bool"
319 | "Vec"
320 | "Option"
321 | "HashMap"
322 | "BTreeMap"
323 | "HashSet"
324 | "BTreeSet"
325 )
326 } else {
327 true
329 }
330}
331
332fn format_type_for_schema(ty: &syn::Type) -> String {
334 match ty {
336 syn::Type::Path(type_path) => {
337 let path = &type_path.path;
338 if let Some(last_segment) = path.segments.last() {
339 let type_name = last_segment.ident.to_string();
340
341 if type_name == "Option"
343 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
344 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
345 {
346 return format!("{} | null", format_type_for_schema(inner_type));
347 }
348
349 match type_name.as_str() {
351 "String" | "str" => "string".to_string(),
352 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
353 | "u64" | "u128" | "usize" => "number".to_string(),
354 "f32" | "f64" => "number".to_string(),
355 "bool" => "boolean".to_string(),
356 "Vec" => {
357 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
358 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
359 {
360 return format!("{}[]", format_type_for_schema(inner_type));
361 }
362 "array".to_string()
363 }
364 _ => type_name.to_lowercase(),
365 }
366 } else {
367 "unknown".to_string()
368 }
369 }
370 _ => "unknown".to_string(),
371 }
372}
373
/// Outcome of reading the `#[prompt(...)]` attribute on an enum variant.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the enum's `prompt_schema()` listing.
    Skip,
    /// `#[prompt("...")]`: an explicit description shown next to the variant name.
    Description(String),
    /// No recognized `#[prompt]` attribute was found.
    None,
}
380
381fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
383 for attr in attrs {
384 if attr.path().is_ident("prompt") {
385 if let Ok(meta_list) = attr.meta.require_list() {
387 let tokens = &meta_list.tokens;
388 let tokens_str = tokens.to_string();
389 if tokens_str == "skip" {
390 return PromptAttribute::Skip;
391 }
392 }
393
394 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
396 return PromptAttribute::Description(lit_str.value());
397 }
398 }
399 }
400 PromptAttribute::None
401}
402
/// Field-level options collected from `#[prompt(...)]` by `parse_field_prompt_attrs`.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of generated prompts and schemas.
    skip: bool,
    /// `rename = "..."`: key used in the default `key: value` rendering.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called as `f(&self.field)` to render the value.
    format_with: Option<String>,
    /// `image`: the field contributes its `to_prompt_parts()` instead of a text line.
    image: bool,
    /// `example = "..."`: string used for this field in the `example_only` rendering.
    example: Option<String>,
}
412
413fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
415 let mut result = FieldPromptAttrs::default();
416
417 for attr in attrs {
418 if attr.path().is_ident("prompt") {
419 if let Ok(meta_list) = attr.meta.require_list() {
421 if let Ok(metas) =
423 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
424 {
425 for meta in metas {
426 match meta {
427 Meta::Path(path) if path.is_ident("skip") => {
428 result.skip = true;
429 }
430 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
431 if let syn::Expr::Lit(syn::ExprLit {
432 lit: syn::Lit::Str(lit_str),
433 ..
434 }) = nv.value
435 {
436 result.rename = Some(lit_str.value());
437 }
438 }
439 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
440 if let syn::Expr::Lit(syn::ExprLit {
441 lit: syn::Lit::Str(lit_str),
442 ..
443 }) = nv.value
444 {
445 result.format_with = Some(lit_str.value());
446 }
447 }
448 Meta::Path(path) if path.is_ident("image") => {
449 result.image = true;
450 }
451 Meta::NameValue(nv) if nv.path.is_ident("example") => {
452 if let syn::Expr::Lit(syn::ExprLit {
453 lit: syn::Lit::Str(lit_str),
454 ..
455 }) = nv.value
456 {
457 result.example = Some(lit_str.value());
458 }
459 }
460 _ => {}
461 }
462 }
463 } else if meta_list.tokens.to_string() == "skip" {
464 result.skip = true;
466 } else if meta_list.tokens.to_string() == "image" {
467 result.image = true;
469 }
470 }
471 }
472 }
473
474 result
475}
476
477#[proc_macro_derive(ToPrompt, attributes(prompt))]
520pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
521 let input = parse_macro_input!(input as DeriveInput);
522
523 let found_crate =
524 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
525 let crate_path = match found_crate {
526 FoundCrate::Itself => {
527 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
529 quote!(::#ident)
530 }
531 FoundCrate::Name(name) => {
532 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
533 quote!(::#ident)
534 }
535 };
536
537 match &input.data {
539 Data::Enum(data_enum) => {
540 let enum_name = &input.ident;
542 let enum_docs = extract_doc_comments(&input.attrs);
543
544 let mut prompt_lines = Vec::new();
545
546 if !enum_docs.is_empty() {
548 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
549 } else {
550 prompt_lines.push(format!("{}:", enum_name));
551 }
552 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
554
555 for variant in &data_enum.variants {
557 let variant_name = &variant.ident;
558
559 match parse_prompt_attribute(&variant.attrs) {
561 PromptAttribute::Skip => {
562 continue;
564 }
565 PromptAttribute::Description(desc) => {
566 prompt_lines.push(format!("- {}: {}", variant_name, desc));
568 }
569 PromptAttribute::None => {
570 let variant_docs = extract_doc_comments(&variant.attrs);
572 if !variant_docs.is_empty() {
573 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
574 } else {
575 prompt_lines.push(format!("- {}", variant_name));
576 }
577 }
578 }
579 }
580
581 let prompt_string = prompt_lines.join("\n");
582 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
583
584 let mut match_arms = Vec::new();
586 for variant in &data_enum.variants {
587 let variant_name = &variant.ident;
588
589 match parse_prompt_attribute(&variant.attrs) {
591 PromptAttribute::Skip => {
592 match_arms.push(quote! {
594 Self::#variant_name => stringify!(#variant_name).to_string()
595 });
596 }
597 PromptAttribute::Description(desc) => {
598 match_arms.push(quote! {
600 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
601 });
602 }
603 PromptAttribute::None => {
604 let variant_docs = extract_doc_comments(&variant.attrs);
606 if !variant_docs.is_empty() {
607 match_arms.push(quote! {
608 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
609 });
610 } else {
611 match_arms.push(quote! {
612 Self::#variant_name => stringify!(#variant_name).to_string()
613 });
614 }
615 }
616 }
617 }
618
619 let to_prompt_impl = if match_arms.is_empty() {
620 quote! {
622 fn to_prompt(&self) -> String {
623 match *self {}
624 }
625 }
626 } else {
627 quote! {
628 fn to_prompt(&self) -> String {
629 match self {
630 #(#match_arms),*
631 }
632 }
633 }
634 };
635
636 let expanded = quote! {
637 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
638 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
639 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
640 }
641
642 #to_prompt_impl
643
644 fn prompt_schema() -> String {
645 #prompt_string.to_string()
646 }
647 }
648 };
649
650 TokenStream::from(expanded)
651 }
652 Data::Struct(data_struct) => {
653 let mut template_attr = None;
655 let mut template_file_attr = None;
656 let mut mode_attr = None;
657 let mut validate_attr = false;
658 let mut type_marker_attr = false;
659
660 for attr in &input.attrs {
661 if attr.path().is_ident("prompt") {
662 if let Ok(metas) =
664 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
665 {
666 for meta in metas {
667 match meta {
668 Meta::NameValue(nv) if nv.path.is_ident("template") => {
669 if let syn::Expr::Lit(expr_lit) = nv.value
670 && let syn::Lit::Str(lit_str) = expr_lit.lit
671 {
672 template_attr = Some(lit_str.value());
673 }
674 }
675 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
676 if let syn::Expr::Lit(expr_lit) = nv.value
677 && let syn::Lit::Str(lit_str) = expr_lit.lit
678 {
679 template_file_attr = Some(lit_str.value());
680 }
681 }
682 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
683 if let syn::Expr::Lit(expr_lit) = nv.value
684 && let syn::Lit::Str(lit_str) = expr_lit.lit
685 {
686 mode_attr = Some(lit_str.value());
687 }
688 }
689 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
690 if let syn::Expr::Lit(expr_lit) = nv.value
691 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
692 {
693 validate_attr = lit_bool.value();
694 }
695 }
696 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
697 if let syn::Expr::Lit(expr_lit) = nv.value
698 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
699 {
700 type_marker_attr = lit_bool.value();
701 }
702 }
703 Meta::Path(path) if path.is_ident("type_marker") => {
704 type_marker_attr = true;
706 }
707 _ => {}
708 }
709 }
710 }
711 }
712 }
713
714 if template_attr.is_some() && template_file_attr.is_some() {
716 return syn::Error::new(
717 input.ident.span(),
718 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
719 ).to_compile_error().into();
720 }
721
722 let template_str = if let Some(file_path) = template_file_attr {
724 let mut full_path = None;
728
729 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
731 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
733
734 if !is_trybuild {
735 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
737 if candidate.exists() {
738 full_path = Some(candidate);
739 }
740 } else {
741 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
747 let workspace_root = &manifest_dir[..target_pos];
748 let original_macros_dir = std::path::Path::new(workspace_root)
750 .join("crates")
751 .join("llm-toolkit-macros");
752
753 let candidate = original_macros_dir.join(&file_path);
754 if candidate.exists() {
755 full_path = Some(candidate);
756 }
757 }
758 }
759 }
760
761 if full_path.is_none() {
763 let candidate = std::path::Path::new(&file_path).to_path_buf();
764 if candidate.exists() {
765 full_path = Some(candidate);
766 }
767 }
768
769 if full_path.is_none()
772 && let Ok(current_dir) = std::env::current_dir()
773 {
774 let mut search_dir = current_dir.as_path();
775 for _ in 0..10 {
777 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
779 if macros_dir.exists() {
780 let candidate = macros_dir.join(&file_path);
781 if candidate.exists() {
782 full_path = Some(candidate);
783 break;
784 }
785 }
786 let candidate = search_dir.join(&file_path);
788 if candidate.exists() {
789 full_path = Some(candidate);
790 break;
791 }
792 if let Some(parent) = search_dir.parent() {
793 search_dir = parent;
794 } else {
795 break;
796 }
797 }
798 }
799
800 if full_path.is_none() {
802 let mut error_msg = format!(
804 "Template file '{}' not found at compile time.\n\nSearched in:",
805 file_path
806 );
807
808 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
809 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
810 error_msg.push_str(&format!("\n - {}", candidate.display()));
811 }
812
813 if let Ok(current_dir) = std::env::current_dir() {
814 let candidate = current_dir.join(&file_path);
815 error_msg.push_str(&format!("\n - {}", candidate.display()));
816 }
817
818 error_msg.push_str("\n\nPlease ensure:");
819 error_msg.push_str("\n 1. The template file exists");
820 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
821 error_msg.push_str("\n 3. There are no typos in the path");
822
823 return syn::Error::new(input.ident.span(), error_msg)
824 .to_compile_error()
825 .into();
826 }
827
828 let final_path = full_path.unwrap();
829
830 match std::fs::read_to_string(&final_path) {
832 Ok(content) => Some(content),
833 Err(e) => {
834 return syn::Error::new(
835 input.ident.span(),
836 format!(
837 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
838 file_path,
839 e,
840 final_path.display()
841 ),
842 )
843 .to_compile_error()
844 .into();
845 }
846 }
847 } else {
848 template_attr
849 };
850
851 if validate_attr && let Some(template) = &template_str {
853 let mut env = minijinja::Environment::new();
855 if let Err(e) = env.add_template("validation", template) {
856 let warning_msg =
858 format!("Template validation warning: Invalid Jinja syntax - {}", e);
859 let warning_ident = syn::Ident::new(
860 "TEMPLATE_VALIDATION_WARNING",
861 proc_macro2::Span::call_site(),
862 );
863 let _warning_tokens = quote! {
864 #[deprecated(note = #warning_msg)]
865 const #warning_ident: () = ();
866 let _ = #warning_ident;
867 };
868 eprintln!("cargo:warning={}", warning_msg);
870 }
871
872 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
874 &fields.named
875 } else {
876 panic!("Template validation is only supported for structs with named fields.");
877 };
878
879 let field_names: std::collections::HashSet<String> = fields
880 .iter()
881 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
882 .collect();
883
884 let placeholders = parse_template_placeholders_with_mode(template);
886
887 for (placeholder_name, _mode) in &placeholders {
888 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
889 let warning_msg = format!(
890 "Template validation warning: Variable '{}' used in template but not found in struct fields",
891 placeholder_name
892 );
893 eprintln!("cargo:warning={}", warning_msg);
894 }
895 }
896 }
897
898 let name = input.ident;
899 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
900
901 let struct_docs = extract_doc_comments(&input.attrs);
903
904 let is_mode_based =
906 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
907
908 let expanded = if is_mode_based || mode_attr.is_some() {
909 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
911 &fields.named
912 } else {
913 panic!(
914 "Mode-based prompt generation is only supported for structs with named fields."
915 );
916 };
917
918 let struct_name_str = name.to_string();
919
920 let has_default = input.attrs.iter().any(|attr| {
922 if attr.path().is_ident("derive")
923 && let Ok(meta_list) = attr.meta.require_list()
924 {
925 let tokens_str = meta_list.tokens.to_string();
926 tokens_str.contains("Default")
927 } else {
928 false
929 }
930 });
931
932 let schema_parts = generate_schema_only_parts(
934 &struct_name_str,
935 &struct_docs,
936 fields,
937 &crate_path,
938 type_marker_attr,
939 );
940
941 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
943
944 quote! {
945 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
946 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
947 match mode {
948 "schema_only" => #schema_parts,
949 "example_only" => #example_parts,
950 "full" | _ => {
951 let mut parts = Vec::new();
953
954 let schema_parts = #schema_parts;
956 parts.extend(schema_parts);
957
958 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
960 parts.push(#crate_path::prompt::PromptPart::Text(
961 format!("Here is an example of a valid `{}` object:", #struct_name_str)
962 ));
963
964 let example_parts = #example_parts;
966 parts.extend(example_parts);
967
968 parts
969 }
970 }
971 }
972
973 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
974 self.to_prompt_parts_with_mode("full")
975 }
976
977 fn to_prompt(&self) -> String {
978 self.to_prompt_parts()
979 .into_iter()
980 .filter_map(|part| match part {
981 #crate_path::prompt::PromptPart::Text(text) => Some(text),
982 _ => None,
983 })
984 .collect::<Vec<_>>()
985 .join("\n")
986 }
987
988 fn prompt_schema() -> String {
989 use std::sync::OnceLock;
990 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
991
992 SCHEMA_CACHE.get_or_init(|| {
993 let schema_parts = #schema_parts;
994 schema_parts
995 .into_iter()
996 .filter_map(|part| match part {
997 #crate_path::prompt::PromptPart::Text(text) => Some(text),
998 _ => None,
999 })
1000 .collect::<Vec<_>>()
1001 .join("\n")
1002 }).clone()
1003 }
1004 }
1005 }
1006 } else if let Some(template) = template_str {
1007 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1010 &fields.named
1011 } else {
1012 panic!(
1013 "Template prompt generation is only supported for structs with named fields."
1014 );
1015 };
1016
1017 let placeholders = parse_template_placeholders_with_mode(&template);
1019 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1021 mode.is_some()
1022 && fields
1023 .iter()
1024 .any(|f| f.ident.as_ref().unwrap() == field_name)
1025 });
1026
1027 let mut image_field_parts = Vec::new();
1028 for f in fields.iter() {
1029 let field_name = f.ident.as_ref().unwrap();
1030 let attrs = parse_field_prompt_attrs(&f.attrs);
1031
1032 if attrs.image {
1033 image_field_parts.push(quote! {
1035 parts.extend(self.#field_name.to_prompt_parts());
1036 });
1037 }
1038 }
1039
1040 if has_mode_syntax {
1042 let mut context_fields = Vec::new();
1044 let mut modified_template = template.clone();
1045
1046 for (field_name, mode_opt) in &placeholders {
1048 if let Some(mode) = mode_opt {
1049 let unique_key = format!("{}__{}", field_name, mode);
1051
1052 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1054 let replacement = format!("{{{{ {} }}}}", unique_key);
1055 modified_template = modified_template.replace(&pattern, &replacement);
1056
1057 let field_ident =
1059 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1060
1061 context_fields.push(quote! {
1063 context.insert(
1064 #unique_key.to_string(),
1065 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1066 );
1067 });
1068 }
1069 }
1070
1071 for field in fields.iter() {
1073 let field_name = field.ident.as_ref().unwrap();
1074 let field_name_str = field_name.to_string();
1075
1076 let has_mode_entry = placeholders
1078 .iter()
1079 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1080
1081 if !has_mode_entry {
1082 let is_primitive = match &field.ty {
1085 syn::Type::Path(type_path) => {
1086 if let Some(segment) = type_path.path.segments.last() {
1087 let type_name = segment.ident.to_string();
1088 matches!(
1089 type_name.as_str(),
1090 "String"
1091 | "str"
1092 | "i8"
1093 | "i16"
1094 | "i32"
1095 | "i64"
1096 | "i128"
1097 | "isize"
1098 | "u8"
1099 | "u16"
1100 | "u32"
1101 | "u64"
1102 | "u128"
1103 | "usize"
1104 | "f32"
1105 | "f64"
1106 | "bool"
1107 | "char"
1108 )
1109 } else {
1110 false
1111 }
1112 }
1113 _ => false,
1114 };
1115
1116 if is_primitive {
1117 context_fields.push(quote! {
1118 context.insert(
1119 #field_name_str.to_string(),
1120 minijinja::Value::from_serialize(&self.#field_name)
1121 );
1122 });
1123 } else {
1124 context_fields.push(quote! {
1126 context.insert(
1127 #field_name_str.to_string(),
1128 minijinja::Value::from(self.#field_name.to_prompt())
1129 );
1130 });
1131 }
1132 }
1133 }
1134
1135 quote! {
1136 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1137 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1138 let mut parts = Vec::new();
1139
1140 #(#image_field_parts)*
1142
1143 let text = {
1145 let mut env = minijinja::Environment::new();
1146 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1147 panic!("Failed to parse template: {}", e)
1148 });
1149
1150 let tmpl = env.get_template("prompt").unwrap();
1151
1152 let mut context = std::collections::HashMap::new();
1153 #(#context_fields)*
1154
1155 tmpl.render(context).unwrap_or_else(|e| {
1156 format!("Failed to render prompt: {}", e)
1157 })
1158 };
1159
1160 if !text.is_empty() {
1161 parts.push(#crate_path::prompt::PromptPart::Text(text));
1162 }
1163
1164 parts
1165 }
1166
1167 fn to_prompt(&self) -> String {
1168 let mut env = minijinja::Environment::new();
1170 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1171 panic!("Failed to parse template: {}", e)
1172 });
1173
1174 let tmpl = env.get_template("prompt").unwrap();
1175
1176 let mut context = std::collections::HashMap::new();
1177 #(#context_fields)*
1178
1179 tmpl.render(context).unwrap_or_else(|e| {
1180 format!("Failed to render prompt: {}", e)
1181 })
1182 }
1183
1184 fn prompt_schema() -> String {
1185 String::new() }
1187 }
1188 }
1189 } else {
1190 quote! {
1192 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1193 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1194 let mut parts = Vec::new();
1195
1196 #(#image_field_parts)*
1198
1199 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1201 format!("Failed to render prompt: {}", e)
1202 });
1203 if !text.is_empty() {
1204 parts.push(#crate_path::prompt::PromptPart::Text(text));
1205 }
1206
1207 parts
1208 }
1209
1210 fn to_prompt(&self) -> String {
1211 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1212 format!("Failed to render prompt: {}", e)
1213 })
1214 }
1215
1216 fn prompt_schema() -> String {
1217 String::new() }
1219 }
1220 }
1221 }
1222 } else {
1223 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1226 &fields.named
1227 } else {
1228 panic!(
1229 "Default prompt generation is only supported for structs with named fields."
1230 );
1231 };
1232
1233 let mut text_field_parts = Vec::new();
1235 let mut image_field_parts = Vec::new();
1236
1237 for f in fields.iter() {
1238 let field_name = f.ident.as_ref().unwrap();
1239 let attrs = parse_field_prompt_attrs(&f.attrs);
1240
1241 if attrs.skip {
1243 continue;
1244 }
1245
1246 if attrs.image {
1247 image_field_parts.push(quote! {
1249 parts.extend(self.#field_name.to_prompt_parts());
1250 });
1251 } else {
1252 let key = if let Some(rename) = attrs.rename {
1258 rename
1259 } else {
1260 let doc_comment = extract_doc_comments(&f.attrs);
1261 if !doc_comment.is_empty() {
1262 doc_comment
1263 } else {
1264 field_name.to_string()
1265 }
1266 };
1267
1268 let value_expr = if let Some(format_with) = attrs.format_with {
1270 let func_path: syn::Path =
1272 syn::parse_str(&format_with).unwrap_or_else(|_| {
1273 panic!("Invalid function path: {}", format_with)
1274 });
1275 quote! { #func_path(&self.#field_name) }
1276 } else {
1277 quote! { self.#field_name.to_prompt() }
1278 };
1279
1280 text_field_parts.push(quote! {
1281 text_parts.push(format!("{}: {}", #key, #value_expr));
1282 });
1283 }
1284 }
1285
1286 quote! {
1288 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1289 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1290 let mut parts = Vec::new();
1291
1292 #(#image_field_parts)*
1294
1295 let mut text_parts = Vec::new();
1297 #(#text_field_parts)*
1298
1299 if !text_parts.is_empty() {
1300 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1301 }
1302
1303 parts
1304 }
1305
1306 fn to_prompt(&self) -> String {
1307 let mut text_parts = Vec::new();
1308 #(#text_field_parts)*
1309 text_parts.join("\n")
1310 }
1311
1312 fn prompt_schema() -> String {
1313 String::new() }
1315 }
1316 }
1317 };
1318
1319 TokenStream::from(expanded)
1320 }
1321 Data::Union(_) => {
1322 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1323 }
1324 }
1325}
1326
1327#[derive(Debug, Clone)]
1329struct TargetInfo {
1330 name: String,
1331 template: Option<String>,
1332 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1333}
1334
/// Field-level options from one `#[prompt_for(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: omit the field for this target.
    skip: bool,
    /// `rename = "..."`: alternative key for the field.
    rename: Option<String>,
    /// `format_with = "..."`: path of a formatting function for the field.
    format_with: Option<String>,
    /// `image`: treat the field as image content.
    image: bool,
    /// Set when the attribute named an explicit target via `name = "..."`.
    include_only: bool,
}
1344
1345fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1347 let mut configs = Vec::new();
1348
1349 for attr in attrs {
1350 if attr.path().is_ident("prompt_for")
1351 && let Ok(meta_list) = attr.meta.require_list()
1352 {
1353 if meta_list.tokens.to_string() == "skip" {
1355 let config = FieldTargetConfig {
1357 skip: true,
1358 ..Default::default()
1359 };
1360 configs.push(("*".to_string(), config));
1361 } else if let Ok(metas) =
1362 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1363 {
1364 let mut target_name = None;
1365 let mut config = FieldTargetConfig::default();
1366
1367 for meta in metas {
1368 match meta {
1369 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1370 if let syn::Expr::Lit(syn::ExprLit {
1371 lit: syn::Lit::Str(lit_str),
1372 ..
1373 }) = nv.value
1374 {
1375 target_name = Some(lit_str.value());
1376 }
1377 }
1378 Meta::Path(path) if path.is_ident("skip") => {
1379 config.skip = true;
1380 }
1381 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1382 if let syn::Expr::Lit(syn::ExprLit {
1383 lit: syn::Lit::Str(lit_str),
1384 ..
1385 }) = nv.value
1386 {
1387 config.rename = Some(lit_str.value());
1388 }
1389 }
1390 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1391 if let syn::Expr::Lit(syn::ExprLit {
1392 lit: syn::Lit::Str(lit_str),
1393 ..
1394 }) = nv.value
1395 {
1396 config.format_with = Some(lit_str.value());
1397 }
1398 }
1399 Meta::Path(path) if path.is_ident("image") => {
1400 config.image = true;
1401 }
1402 _ => {}
1403 }
1404 }
1405
1406 if let Some(name) = target_name {
1407 config.include_only = true;
1408 configs.push((name, config));
1409 }
1410 }
1411 }
1412 }
1413
1414 configs
1415}
1416
1417fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1419 let mut targets = Vec::new();
1420
1421 for attr in attrs {
1422 if attr.path().is_ident("prompt_for")
1423 && let Ok(meta_list) = attr.meta.require_list()
1424 && let Ok(metas) =
1425 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1426 {
1427 let mut target_name = None;
1428 let mut template = None;
1429
1430 for meta in metas {
1431 match meta {
1432 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1433 if let syn::Expr::Lit(syn::ExprLit {
1434 lit: syn::Lit::Str(lit_str),
1435 ..
1436 }) = nv.value
1437 {
1438 target_name = Some(lit_str.value());
1439 }
1440 }
1441 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1442 if let syn::Expr::Lit(syn::ExprLit {
1443 lit: syn::Lit::Str(lit_str),
1444 ..
1445 }) = nv.value
1446 {
1447 template = Some(lit_str.value());
1448 }
1449 }
1450 _ => {}
1451 }
1452 }
1453
1454 if let Some(name) = target_name {
1455 targets.push(TargetInfo {
1456 name,
1457 template,
1458 field_configs: std::collections::HashMap::new(),
1459 });
1460 }
1461 }
1462 }
1463
1464 targets
1465}
1466
1467#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1468pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1469 let input = parse_macro_input!(input as DeriveInput);
1470
1471 let found_crate =
1472 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1473 let crate_path = match found_crate {
1474 FoundCrate::Itself => {
1475 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1477 quote!(::#ident)
1478 }
1479 FoundCrate::Name(name) => {
1480 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1481 quote!(::#ident)
1482 }
1483 };
1484
1485 let data_struct = match &input.data {
1487 Data::Struct(data) => data,
1488 _ => {
1489 return syn::Error::new(
1490 input.ident.span(),
1491 "`#[derive(ToPromptSet)]` is only supported for structs",
1492 )
1493 .to_compile_error()
1494 .into();
1495 }
1496 };
1497
1498 let fields = match &data_struct.fields {
1499 syn::Fields::Named(fields) => &fields.named,
1500 _ => {
1501 return syn::Error::new(
1502 input.ident.span(),
1503 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1504 )
1505 .to_compile_error()
1506 .into();
1507 }
1508 };
1509
1510 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1512
1513 for field in fields.iter() {
1515 let field_name = field.ident.as_ref().unwrap().to_string();
1516 let field_configs = parse_prompt_for_attrs(&field.attrs);
1517
1518 for (target_name, config) in field_configs {
1519 if target_name == "*" {
1520 for target in &mut targets {
1522 target
1523 .field_configs
1524 .entry(field_name.clone())
1525 .or_insert_with(FieldTargetConfig::default)
1526 .skip = config.skip;
1527 }
1528 } else {
1529 let target_exists = targets.iter().any(|t| t.name == target_name);
1531 if !target_exists {
1532 targets.push(TargetInfo {
1534 name: target_name.clone(),
1535 template: None,
1536 field_configs: std::collections::HashMap::new(),
1537 });
1538 }
1539
1540 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1541
1542 target.field_configs.insert(field_name.clone(), config);
1543 }
1544 }
1545 }
1546
1547 let mut match_arms = Vec::new();
1549
1550 for target in &targets {
1551 let target_name = &target.name;
1552
1553 if let Some(template_str) = &target.template {
1554 let mut image_parts = Vec::new();
1556
1557 for field in fields.iter() {
1558 let field_name = field.ident.as_ref().unwrap();
1559 let field_name_str = field_name.to_string();
1560
1561 if let Some(config) = target.field_configs.get(&field_name_str)
1562 && config.image
1563 {
1564 image_parts.push(quote! {
1565 parts.extend(self.#field_name.to_prompt_parts());
1566 });
1567 }
1568 }
1569
1570 match_arms.push(quote! {
1571 #target_name => {
1572 let mut parts = Vec::new();
1573
1574 #(#image_parts)*
1575
1576 let text = #crate_path::prompt::render_prompt(#template_str, self)
1577 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1578 target: #target_name.to_string(),
1579 source: e,
1580 })?;
1581
1582 if !text.is_empty() {
1583 parts.push(#crate_path::prompt::PromptPart::Text(text));
1584 }
1585
1586 Ok(parts)
1587 }
1588 });
1589 } else {
1590 let mut text_field_parts = Vec::new();
1592 let mut image_field_parts = Vec::new();
1593
1594 for field in fields.iter() {
1595 let field_name = field.ident.as_ref().unwrap();
1596 let field_name_str = field_name.to_string();
1597
1598 let config = target.field_configs.get(&field_name_str);
1600
1601 if let Some(cfg) = config
1603 && cfg.skip
1604 {
1605 continue;
1606 }
1607
1608 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1612 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1613 .iter()
1614 .any(|(name, _)| name != "*");
1615
1616 if has_any_target_specific_config && !is_explicitly_for_this_target {
1617 continue;
1618 }
1619
1620 if let Some(cfg) = config {
1621 if cfg.image {
1622 image_field_parts.push(quote! {
1623 parts.extend(self.#field_name.to_prompt_parts());
1624 });
1625 } else {
1626 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1627
1628 let value_expr = if let Some(format_with) = &cfg.format_with {
1629 match syn::parse_str::<syn::Path>(format_with) {
1631 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1632 Err(_) => {
1633 let error_msg = format!(
1635 "Invalid function path in format_with: '{}'",
1636 format_with
1637 );
1638 quote! {
1639 compile_error!(#error_msg);
1640 String::new()
1641 }
1642 }
1643 }
1644 } else {
1645 quote! { self.#field_name.to_prompt() }
1646 };
1647
1648 text_field_parts.push(quote! {
1649 text_parts.push(format!("{}: {}", #key, #value_expr));
1650 });
1651 }
1652 } else {
1653 text_field_parts.push(quote! {
1655 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1656 });
1657 }
1658 }
1659
1660 match_arms.push(quote! {
1661 #target_name => {
1662 let mut parts = Vec::new();
1663
1664 #(#image_field_parts)*
1665
1666 let mut text_parts = Vec::new();
1667 #(#text_field_parts)*
1668
1669 if !text_parts.is_empty() {
1670 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1671 }
1672
1673 Ok(parts)
1674 }
1675 });
1676 }
1677 }
1678
1679 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1681
1682 match_arms.push(quote! {
1684 _ => {
1685 let available = vec![#(#target_names.to_string()),*];
1686 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1687 target: target.to_string(),
1688 available,
1689 })
1690 }
1691 });
1692
1693 let struct_name = &input.ident;
1694 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1695
1696 let expanded = quote! {
1697 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1698 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1699 match target {
1700 #(#match_arms)*
1701 }
1702 }
1703 }
1704 };
1705
1706 TokenStream::from(expanded)
1707}
1708
1709struct TypeList {
1711 types: Punctuated<syn::Type, Token![,]>,
1712}
1713
1714impl Parse for TypeList {
1715 fn parse(input: ParseStream) -> syn::Result<Self> {
1716 Ok(TypeList {
1717 types: Punctuated::parse_terminated(input)?,
1718 })
1719 }
1720}
1721
1722#[proc_macro]
1746pub fn examples_section(input: TokenStream) -> TokenStream {
1747 let input = parse_macro_input!(input as TypeList);
1748
1749 let found_crate =
1750 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1751 let _crate_path = match found_crate {
1752 FoundCrate::Itself => quote!(crate),
1753 FoundCrate::Name(name) => {
1754 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1755 quote!(::#ident)
1756 }
1757 };
1758
1759 let mut type_sections = Vec::new();
1761
1762 for ty in input.types.iter() {
1763 let type_name_str = quote!(#ty).to_string();
1765
1766 type_sections.push(quote! {
1768 {
1769 let type_name = #type_name_str;
1770 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1771 format!("---\n#### `{}`\n{}", type_name, json_example)
1772 }
1773 });
1774 }
1775
1776 let expanded = quote! {
1778 {
1779 let mut sections = Vec::new();
1780 sections.push("---".to_string());
1781 sections.push("### Examples".to_string());
1782 sections.push("".to_string());
1783 sections.push("Here are examples of the data structures you should use.".to_string());
1784 sections.push("".to_string());
1785
1786 #(sections.push(#type_sections);)*
1787
1788 sections.push("---".to_string());
1789
1790 sections.join("\n")
1791 }
1792 };
1793
1794 TokenStream::from(expanded)
1795}
1796
1797fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1799 for attr in attrs {
1800 if attr.path().is_ident("prompt_for")
1801 && let Ok(meta_list) = attr.meta.require_list()
1802 && let Ok(metas) =
1803 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1804 {
1805 let mut target_type = None;
1806 let mut template = None;
1807
1808 for meta in metas {
1809 match meta {
1810 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1811 if let syn::Expr::Lit(syn::ExprLit {
1812 lit: syn::Lit::Str(lit_str),
1813 ..
1814 }) = nv.value
1815 {
1816 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1818 }
1819 }
1820 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1821 if let syn::Expr::Lit(syn::ExprLit {
1822 lit: syn::Lit::Str(lit_str),
1823 ..
1824 }) = nv.value
1825 {
1826 template = Some(lit_str.value());
1827 }
1828 }
1829 _ => {}
1830 }
1831 }
1832
1833 if let (Some(target), Some(tmpl)) = (target_type, template) {
1834 return (target, tmpl);
1835 }
1836 }
1837 }
1838
1839 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1840}
1841
1842#[proc_macro_attribute]
1876pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1877 let input = parse_macro_input!(item as DeriveInput);
1878
1879 let found_crate =
1880 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1881 let crate_path = match found_crate {
1882 FoundCrate::Itself => {
1883 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1885 quote!(::#ident)
1886 }
1887 FoundCrate::Name(name) => {
1888 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1889 quote!(::#ident)
1890 }
1891 };
1892
1893 let enum_data = match &input.data {
1895 Data::Enum(data) => data,
1896 _ => {
1897 return syn::Error::new(
1898 input.ident.span(),
1899 "`#[define_intent]` can only be applied to enums",
1900 )
1901 .to_compile_error()
1902 .into();
1903 }
1904 };
1905
1906 let mut prompt_template = None;
1908 let mut extractor_tag = None;
1909 let mut mode = None;
1910
1911 for attr in &input.attrs {
1912 if attr.path().is_ident("intent")
1913 && let Ok(metas) =
1914 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1915 {
1916 for meta in metas {
1917 match meta {
1918 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1919 if let syn::Expr::Lit(syn::ExprLit {
1920 lit: syn::Lit::Str(lit_str),
1921 ..
1922 }) = nv.value
1923 {
1924 prompt_template = Some(lit_str.value());
1925 }
1926 }
1927 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1928 if let syn::Expr::Lit(syn::ExprLit {
1929 lit: syn::Lit::Str(lit_str),
1930 ..
1931 }) = nv.value
1932 {
1933 extractor_tag = Some(lit_str.value());
1934 }
1935 }
1936 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1937 if let syn::Expr::Lit(syn::ExprLit {
1938 lit: syn::Lit::Str(lit_str),
1939 ..
1940 }) = nv.value
1941 {
1942 mode = Some(lit_str.value());
1943 }
1944 }
1945 _ => {}
1946 }
1947 }
1948 }
1949 }
1950
1951 let mode = mode.unwrap_or_else(|| "single".to_string());
1953
1954 if mode != "single" && mode != "multi_tag" {
1956 return syn::Error::new(
1957 input.ident.span(),
1958 "`mode` must be either \"single\" or \"multi_tag\"",
1959 )
1960 .to_compile_error()
1961 .into();
1962 }
1963
1964 let prompt_template = match prompt_template {
1966 Some(p) => p,
1967 None => {
1968 return syn::Error::new(
1969 input.ident.span(),
1970 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1971 )
1972 .to_compile_error()
1973 .into();
1974 }
1975 };
1976
1977 if mode == "multi_tag" {
1979 let enum_name = &input.ident;
1980 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1981 return generate_multi_tag_output(
1982 &input,
1983 enum_name,
1984 enum_data,
1985 prompt_template,
1986 actions_doc,
1987 );
1988 }
1989
1990 let extractor_tag = match extractor_tag {
1992 Some(t) => t,
1993 None => {
1994 return syn::Error::new(
1995 input.ident.span(),
1996 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1997 )
1998 .to_compile_error()
1999 .into();
2000 }
2001 };
2002
2003 let enum_name = &input.ident;
2005 let enum_docs = extract_doc_comments(&input.attrs);
2006
2007 let mut intents_doc_lines = Vec::new();
2008
2009 if !enum_docs.is_empty() {
2011 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2012 } else {
2013 intents_doc_lines.push(format!("{}:", enum_name));
2014 }
2015 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2017
2018 for variant in &enum_data.variants {
2020 let variant_name = &variant.ident;
2021 let variant_docs = extract_doc_comments(&variant.attrs);
2022
2023 if !variant_docs.is_empty() {
2024 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2025 } else {
2026 intents_doc_lines.push(format!("- {}", variant_name));
2027 }
2028 }
2029
2030 let intents_doc_str = intents_doc_lines.join("\n");
2031
2032 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2034 let user_variables: Vec<String> = placeholders
2035 .iter()
2036 .filter_map(|(name, _)| {
2037 if name != "intents_doc" {
2038 Some(name.clone())
2039 } else {
2040 None
2041 }
2042 })
2043 .collect();
2044
2045 let enum_name_str = enum_name.to_string();
2047 let snake_case_name = to_snake_case(&enum_name_str);
2048 let function_name = syn::Ident::new(
2049 &format!("build_{}_prompt", snake_case_name),
2050 proc_macro2::Span::call_site(),
2051 );
2052
2053 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2055 .iter()
2056 .map(|var| {
2057 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2058 quote! { #ident: &str }
2059 })
2060 .collect();
2061
2062 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2064 .iter()
2065 .map(|var| {
2066 let var_str = var.clone();
2067 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2068 quote! {
2069 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2070 }
2071 })
2072 .collect();
2073
2074 let converted_template = prompt_template.clone();
2076
2077 let extractor_name = syn::Ident::new(
2079 &format!("{}Extractor", enum_name),
2080 proc_macro2::Span::call_site(),
2081 );
2082
2083 let filtered_attrs: Vec<_> = input
2085 .attrs
2086 .iter()
2087 .filter(|attr| !attr.path().is_ident("intent"))
2088 .collect();
2089
2090 let vis = &input.vis;
2092 let generics = &input.generics;
2093 let variants = &enum_data.variants;
2094 let enum_output = quote! {
2095 #(#filtered_attrs)*
2096 #vis enum #enum_name #generics {
2097 #variants
2098 }
2099 };
2100
2101 let expanded = quote! {
2103 #enum_output
2105
2106 pub fn #function_name(#(#function_params),*) -> String {
2108 let mut env = minijinja::Environment::new();
2109 env.add_template("prompt", #converted_template)
2110 .expect("Failed to parse intent prompt template");
2111
2112 let tmpl = env.get_template("prompt").unwrap();
2113
2114 let mut __template_context = std::collections::HashMap::new();
2115
2116 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2118
2119 #(#context_insertions)*
2121
2122 tmpl.render(&__template_context)
2123 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2124 }
2125
2126 pub struct #extractor_name;
2128
2129 impl #extractor_name {
2130 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2131 }
2132
2133 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2134 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2135 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2137 }
2138 }
2139 };
2140
2141 TokenStream::from(expanded)
2142}
2143
/// Converts a CamelCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character only when it is not the
/// first character and the previous character was not uppercase. Runs of capitals
/// therefore stay together (`"HTTPServer"` becomes `"httpserver"`). Each uppercase
/// character is replaced by the first character of its lowercase mapping.
fn to_snake_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 4);
    let mut after_upper = false;

    for (offset, c) in s.char_indices() {
        let is_upper = c.is_uppercase();
        if is_upper {
            // `offset != 0` is true exactly for every character after the first.
            if offset != 0 && !after_upper {
                out.push('_');
            }
            out.push(c.to_lowercase().next().unwrap());
        } else {
            out.push(c);
        }
        after_upper = is_upper;
    }

    out
}
2164
/// Variant-level options read from `#[action(...)]`.
#[derive(Default, Debug)]
struct ActionAttrs {
    /// XML tag name given by `tag = "..."`, if any.
    tag: Option<String>,
}
2170
2171fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2172 let mut result = ActionAttrs::default();
2173
2174 for attr in attrs {
2175 if attr.path().is_ident("action")
2176 && let Ok(meta_list) = attr.meta.require_list()
2177 && let Ok(metas) =
2178 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2179 {
2180 for meta in metas {
2181 if let Meta::NameValue(nv) = meta
2182 && nv.path.is_ident("tag")
2183 && let syn::Expr::Lit(syn::ExprLit {
2184 lit: syn::Lit::Str(lit_str),
2185 ..
2186 }) = nv.value
2187 {
2188 result.tag = Some(lit_str.value());
2189 }
2190 }
2191 }
2192 }
2193
2194 result
2195}
2196
/// Field-level markers read from `#[action(attribute)]` / `#[action(inner_text)]`.
#[derive(Default, Debug)]
struct FieldActionAttrs {
    /// The field maps to an XML attribute of the same name.
    is_attribute: bool,
    /// The field receives the element's text content.
    is_inner_text: bool,
}
2203
2204fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2205 let mut result = FieldActionAttrs::default();
2206
2207 for attr in attrs {
2208 if attr.path().is_ident("action")
2209 && let Ok(meta_list) = attr.meta.require_list()
2210 {
2211 let tokens_str = meta_list.tokens.to_string();
2212 if tokens_str == "attribute" {
2213 result.is_attribute = true;
2214 } else if tokens_str == "inner_text" {
2215 result.is_inner_text = true;
2216 }
2217 }
2218 }
2219
2220 result
2221}
2222
2223fn generate_multi_tag_actions_doc(
2225 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2226) -> String {
2227 let mut doc_lines = Vec::new();
2228
2229 for variant in variants {
2230 let action_attrs = parse_action_attrs(&variant.attrs);
2231
2232 if let Some(tag) = action_attrs.tag {
2233 let variant_docs = extract_doc_comments(&variant.attrs);
2234
2235 match &variant.fields {
2236 syn::Fields::Unit => {
2237 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2239 }
2240 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2241 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2243 }
2244 syn::Fields::Named(fields) => {
2245 let mut attrs_str = Vec::new();
2247 let mut has_inner_text = false;
2248
2249 for field in &fields.named {
2250 let field_name = field.ident.as_ref().unwrap();
2251 let field_attrs = parse_field_action_attrs(&field.attrs);
2252
2253 if field_attrs.is_attribute {
2254 attrs_str.push(format!("{}=\"...\"", field_name));
2255 } else if field_attrs.is_inner_text {
2256 has_inner_text = true;
2257 }
2258 }
2259
2260 let attrs_part = if !attrs_str.is_empty() {
2261 format!(" {}", attrs_str.join(" "))
2262 } else {
2263 String::new()
2264 };
2265
2266 if has_inner_text {
2267 doc_lines.push(format!(
2268 "- `<{}{}>...</{}>`: {}",
2269 tag, attrs_part, tag, variant_docs
2270 ));
2271 } else if !attrs_str.is_empty() {
2272 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2273 } else {
2274 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2275 }
2276
2277 for field in &fields.named {
2279 let field_name = field.ident.as_ref().unwrap();
2280 let field_attrs = parse_field_action_attrs(&field.attrs);
2281 let field_docs = extract_doc_comments(&field.attrs);
2282
2283 if field_attrs.is_attribute {
2284 doc_lines
2285 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2286 } else if field_attrs.is_inner_text {
2287 doc_lines
2288 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2289 }
2290 }
2291 }
2292 _ => {
2293 }
2295 }
2296 }
2297 }
2298
2299 doc_lines.join("\n")
2300}
2301
2302fn generate_tags_regex(
2304 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2305) -> String {
2306 let mut tag_names = Vec::new();
2307
2308 for variant in variants {
2309 let action_attrs = parse_action_attrs(&variant.attrs);
2310 if let Some(tag) = action_attrs.tag {
2311 tag_names.push(tag);
2312 }
2313 }
2314
2315 if tag_names.is_empty() {
2316 return String::new();
2317 }
2318
2319 let tags_pattern = tag_names.join("|");
2320 format!(
2323 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2324 tags_pattern, tags_pattern, tags_pattern
2325 )
2326}
2327
2328fn generate_multi_tag_output(
2330 input: &DeriveInput,
2331 enum_name: &syn::Ident,
2332 enum_data: &syn::DataEnum,
2333 prompt_template: String,
2334 actions_doc: String,
2335) -> TokenStream {
2336 let found_crate =
2337 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2338 let crate_path = match found_crate {
2339 FoundCrate::Itself => {
2340 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2342 quote!(::#ident)
2343 }
2344 FoundCrate::Name(name) => {
2345 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2346 quote!(::#ident)
2347 }
2348 };
2349
2350 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2352 let user_variables: Vec<String> = placeholders
2353 .iter()
2354 .filter_map(|(name, _)| {
2355 if name != "actions_doc" {
2356 Some(name.clone())
2357 } else {
2358 None
2359 }
2360 })
2361 .collect();
2362
2363 let enum_name_str = enum_name.to_string();
2365 let snake_case_name = to_snake_case(&enum_name_str);
2366 let function_name = syn::Ident::new(
2367 &format!("build_{}_prompt", snake_case_name),
2368 proc_macro2::Span::call_site(),
2369 );
2370
2371 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2373 .iter()
2374 .map(|var| {
2375 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2376 quote! { #ident: &str }
2377 })
2378 .collect();
2379
2380 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2382 .iter()
2383 .map(|var| {
2384 let var_str = var.clone();
2385 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2386 quote! {
2387 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2388 }
2389 })
2390 .collect();
2391
2392 let extractor_name = syn::Ident::new(
2394 &format!("{}Extractor", enum_name),
2395 proc_macro2::Span::call_site(),
2396 );
2397
2398 let filtered_attrs: Vec<_> = input
2400 .attrs
2401 .iter()
2402 .filter(|attr| !attr.path().is_ident("intent"))
2403 .collect();
2404
2405 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2407 .variants
2408 .iter()
2409 .map(|variant| {
2410 let variant_name = &variant.ident;
2411 let variant_attrs: Vec<_> = variant
2412 .attrs
2413 .iter()
2414 .filter(|attr| !attr.path().is_ident("action"))
2415 .collect();
2416 let fields = &variant.fields;
2417
2418 let filtered_fields = match fields {
2420 syn::Fields::Named(named_fields) => {
2421 let filtered: Vec<_> = named_fields
2422 .named
2423 .iter()
2424 .map(|field| {
2425 let field_name = &field.ident;
2426 let field_type = &field.ty;
2427 let field_vis = &field.vis;
2428 let filtered_attrs: Vec<_> = field
2429 .attrs
2430 .iter()
2431 .filter(|attr| !attr.path().is_ident("action"))
2432 .collect();
2433 quote! {
2434 #(#filtered_attrs)*
2435 #field_vis #field_name: #field_type
2436 }
2437 })
2438 .collect();
2439 quote! { { #(#filtered,)* } }
2440 }
2441 syn::Fields::Unnamed(unnamed_fields) => {
2442 let types: Vec<_> = unnamed_fields
2443 .unnamed
2444 .iter()
2445 .map(|field| {
2446 let field_type = &field.ty;
2447 quote! { #field_type }
2448 })
2449 .collect();
2450 quote! { (#(#types),*) }
2451 }
2452 syn::Fields::Unit => quote! {},
2453 };
2454
2455 quote! {
2456 #(#variant_attrs)*
2457 #variant_name #filtered_fields
2458 }
2459 })
2460 .collect();
2461
2462 let vis = &input.vis;
2463 let generics = &input.generics;
2464
2465 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2467
2468 let tags_regex = generate_tags_regex(&enum_data.variants);
2470
2471 let expanded = quote! {
2472 #(#filtered_attrs)*
2474 #vis enum #enum_name #generics {
2475 #(#filtered_variants),*
2476 }
2477
2478 pub fn #function_name(#(#function_params),*) -> String {
2480 let mut env = minijinja::Environment::new();
2481 env.add_template("prompt", #prompt_template)
2482 .expect("Failed to parse intent prompt template");
2483
2484 let tmpl = env.get_template("prompt").unwrap();
2485
2486 let mut __template_context = std::collections::HashMap::new();
2487
2488 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2490
2491 #(#context_insertions)*
2493
2494 tmpl.render(&__template_context)
2495 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2496 }
2497
2498 pub struct #extractor_name;
2500
2501 impl #extractor_name {
2502 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2503 use ::quick_xml::events::Event;
2504 use ::quick_xml::Reader;
2505
2506 let mut actions = Vec::new();
2507 let mut reader = Reader::from_str(text);
2508 reader.config_mut().trim_text(true);
2509
2510 let mut buf = Vec::new();
2511
2512 loop {
2513 match reader.read_event_into(&mut buf) {
2514 Ok(Event::Start(e)) => {
2515 let owned_e = e.into_owned();
2516 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2517 let is_empty = false;
2518
2519 #parsing_arms
2520 }
2521 Ok(Event::Empty(e)) => {
2522 let owned_e = e.into_owned();
2523 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2524 let is_empty = true;
2525
2526 #parsing_arms
2527 }
2528 Ok(Event::Eof) => break,
2529 Err(_) => {
2530 break;
2532 }
2533 _ => {}
2534 }
2535 buf.clear();
2536 }
2537
2538 actions.into_iter().next()
2539 }
2540
2541 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2542 use ::quick_xml::events::Event;
2543 use ::quick_xml::Reader;
2544
2545 let mut actions = Vec::new();
2546 let mut reader = Reader::from_str(text);
2547 reader.config_mut().trim_text(true);
2548
2549 let mut buf = Vec::new();
2550
2551 loop {
2552 match reader.read_event_into(&mut buf) {
2553 Ok(Event::Start(e)) => {
2554 let owned_e = e.into_owned();
2555 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2556 let is_empty = false;
2557
2558 #parsing_arms
2559 }
2560 Ok(Event::Empty(e)) => {
2561 let owned_e = e.into_owned();
2562 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2563 let is_empty = true;
2564
2565 #parsing_arms
2566 }
2567 Ok(Event::Eof) => break,
2568 Err(_) => {
2569 break;
2571 }
2572 _ => {}
2573 }
2574 buf.clear();
2575 }
2576
2577 Ok(actions)
2578 }
2579
2580 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2581 where
2582 F: FnMut(#enum_name) -> String,
2583 {
2584 use ::regex::Regex;
2585
2586 let regex_pattern = #tags_regex;
2587 if regex_pattern.is_empty() {
2588 return text.to_string();
2589 }
2590
2591 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2592 panic!("Failed to compile regex for action tags: {}", e);
2593 });
2594
2595 re.replace_all(text, |caps: &::regex::Captures| {
2596 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2597
2598 if let Some(action) = self.parse_single_action(matched) {
2600 transformer(action)
2601 } else {
2602 matched.to_string()
2604 }
2605 }).to_string()
2606 }
2607
2608 pub fn strip_actions(&self, text: &str) -> String {
2609 self.transform_actions(text, |_| String::new())
2610 }
2611 }
2612 };
2613
2614 TokenStream::from(expanded)
2615}
2616
2617fn generate_parsing_arms(
2619 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2620 enum_name: &syn::Ident,
2621) -> proc_macro2::TokenStream {
2622 let mut arms = Vec::new();
2623
2624 for variant in variants {
2625 let variant_name = &variant.ident;
2626 let action_attrs = parse_action_attrs(&variant.attrs);
2627
2628 if let Some(tag) = action_attrs.tag {
2629 match &variant.fields {
2630 syn::Fields::Unit => {
2631 arms.push(quote! {
2633 if &tag_name == #tag {
2634 actions.push(#enum_name::#variant_name);
2635 }
2636 });
2637 }
2638 syn::Fields::Unnamed(_fields) => {
2639 arms.push(quote! {
2641 if &tag_name == #tag && !is_empty {
2642 match reader.read_text(owned_e.name()) {
2644 Ok(text) => {
2645 actions.push(#enum_name::#variant_name(text.to_string()));
2646 }
2647 Err(_) => {
2648 actions.push(#enum_name::#variant_name(String::new()));
2650 }
2651 }
2652 }
2653 });
2654 }
2655 syn::Fields::Named(fields) => {
2656 let mut field_names = Vec::new();
2658 let mut has_inner_text_field = None;
2659
2660 for field in &fields.named {
2661 let field_name = field.ident.as_ref().unwrap();
2662 let field_attrs = parse_field_action_attrs(&field.attrs);
2663
2664 if field_attrs.is_attribute {
2665 field_names.push(field_name.clone());
2666 } else if field_attrs.is_inner_text {
2667 has_inner_text_field = Some(field_name.clone());
2668 }
2669 }
2670
2671 if let Some(inner_text_field) = has_inner_text_field {
2672 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2675 quote! {
2676 let mut #field_name = String::new();
2677 for attr in owned_e.attributes() {
2678 if let Ok(attr) = attr {
2679 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2680 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2681 break;
2682 }
2683 }
2684 }
2685 }
2686 }).collect();
2687
2688 arms.push(quote! {
2689 if &tag_name == #tag {
2690 #(#attr_extractions)*
2691
2692 if is_empty {
2694 let #inner_text_field = String::new();
2695 actions.push(#enum_name::#variant_name {
2696 #(#field_names,)*
2697 #inner_text_field,
2698 });
2699 } else {
2700 match reader.read_text(owned_e.name()) {
2702 Ok(text) => {
2703 let #inner_text_field = text.to_string();
2704 actions.push(#enum_name::#variant_name {
2705 #(#field_names,)*
2706 #inner_text_field,
2707 });
2708 }
2709 Err(_) => {
2710 let #inner_text_field = String::new();
2712 actions.push(#enum_name::#variant_name {
2713 #(#field_names,)*
2714 #inner_text_field,
2715 });
2716 }
2717 }
2718 }
2719 }
2720 });
2721 } else {
2722 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2724 quote! {
2725 let mut #field_name = String::new();
2726 for attr in owned_e.attributes() {
2727 if let Ok(attr) = attr {
2728 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2729 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2730 break;
2731 }
2732 }
2733 }
2734 }
2735 }).collect();
2736
2737 arms.push(quote! {
2738 if &tag_name == #tag {
2739 #(#attr_extractions)*
2740 actions.push(#enum_name::#variant_name {
2741 #(#field_names),*
2742 });
2743 }
2744 });
2745 }
2746 }
2747 }
2748 }
2749 }
2750
2751 quote! {
2752 #(#arms)*
2753 }
2754}
2755
2756#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2758pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2759 let input = parse_macro_input!(input as DeriveInput);
2760
2761 let found_crate =
2762 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2763 let crate_path = match found_crate {
2764 FoundCrate::Itself => {
2765 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2767 quote!(::#ident)
2768 }
2769 FoundCrate::Name(name) => {
2770 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2771 quote!(::#ident)
2772 }
2773 };
2774
2775 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2777
2778 let struct_name = &input.ident;
2779 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2780
2781 let placeholders = parse_template_placeholders_with_mode(&template);
2783
2784 let mut converted_template = template.clone();
2786 let mut context_fields = Vec::new();
2787
2788 let fields = match &input.data {
2790 Data::Struct(data_struct) => match &data_struct.fields {
2791 syn::Fields::Named(fields) => &fields.named,
2792 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2793 },
2794 _ => panic!("ToPromptFor is only supported for structs"),
2795 };
2796
2797 let has_mode_support = input.attrs.iter().any(|attr| {
2799 if attr.path().is_ident("prompt")
2800 && let Ok(metas) =
2801 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2802 {
2803 for meta in metas {
2804 if let Meta::NameValue(nv) = meta
2805 && nv.path.is_ident("mode")
2806 {
2807 return true;
2808 }
2809 }
2810 }
2811 false
2812 });
2813
2814 for (placeholder_name, mode_opt) in &placeholders {
2816 if placeholder_name == "self" {
2817 if let Some(specific_mode) = mode_opt {
2818 let unique_key = format!("self__{}", specific_mode);
2820
2821 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2823 let replacement = format!("{{{{ {} }}}}", unique_key);
2824 converted_template = converted_template.replace(&pattern, &replacement);
2825
2826 context_fields.push(quote! {
2828 context.insert(
2829 #unique_key.to_string(),
2830 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2831 );
2832 });
2833 } else {
2834 if has_mode_support {
2837 context_fields.push(quote! {
2839 context.insert(
2840 "self".to_string(),
2841 minijinja::Value::from(self.to_prompt_with_mode(mode))
2842 );
2843 });
2844 } else {
2845 context_fields.push(quote! {
2847 context.insert(
2848 "self".to_string(),
2849 minijinja::Value::from(self.to_prompt())
2850 );
2851 });
2852 }
2853 }
2854 } else {
2855 let field_exists = fields.iter().any(|f| {
2858 f.ident
2859 .as_ref()
2860 .is_some_and(|ident| ident == placeholder_name)
2861 });
2862
2863 if field_exists {
2864 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2865
2866 context_fields.push(quote! {
2870 context.insert(
2871 #placeholder_name.to_string(),
2872 minijinja::Value::from_serialize(&self.#field_ident)
2873 );
2874 });
2875 }
2876 }
2878 }
2879
2880 let expanded = quote! {
2881 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2882 where
2883 #target_type: serde::Serialize,
2884 {
2885 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2886 let mut env = minijinja::Environment::new();
2888 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2889 panic!("Failed to parse template: {}", e)
2890 });
2891
2892 let tmpl = env.get_template("prompt").unwrap();
2893
2894 let mut context = std::collections::HashMap::new();
2896 context.insert(
2898 "self".to_string(),
2899 minijinja::Value::from_serialize(self)
2900 );
2901 context.insert(
2903 "target".to_string(),
2904 minijinja::Value::from_serialize(target)
2905 );
2906 #(#context_fields)*
2907
2908 tmpl.render(context).unwrap_or_else(|e| {
2910 format!("Failed to render prompt: {}", e)
2911 })
2912 }
2913 }
2914 };
2915
2916 TokenStream::from(expanded)
2917}
2918
2919struct AgentAttrs {
2925 expertise: Option<String>,
2926 output: Option<syn::Type>,
2927 backend: Option<String>,
2928 model: Option<String>,
2929 inner: Option<String>,
2930 default_inner: Option<String>,
2931 max_retries: Option<u32>,
2932 profile: Option<String>,
2933}
2934
2935impl Parse for AgentAttrs {
2936 fn parse(input: ParseStream) -> syn::Result<Self> {
2937 let mut expertise = None;
2938 let mut output = None;
2939 let mut backend = None;
2940 let mut model = None;
2941 let mut inner = None;
2942 let mut default_inner = None;
2943 let mut max_retries = None;
2944 let mut profile = None;
2945
2946 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2947
2948 for meta in pairs {
2949 match meta {
2950 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2951 if let syn::Expr::Lit(syn::ExprLit {
2952 lit: syn::Lit::Str(lit_str),
2953 ..
2954 }) = &nv.value
2955 {
2956 expertise = Some(lit_str.value());
2957 }
2958 }
2959 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2960 if let syn::Expr::Lit(syn::ExprLit {
2961 lit: syn::Lit::Str(lit_str),
2962 ..
2963 }) = &nv.value
2964 {
2965 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2966 output = Some(ty);
2967 }
2968 }
2969 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2970 if let syn::Expr::Lit(syn::ExprLit {
2971 lit: syn::Lit::Str(lit_str),
2972 ..
2973 }) = &nv.value
2974 {
2975 backend = Some(lit_str.value());
2976 }
2977 }
2978 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2979 if let syn::Expr::Lit(syn::ExprLit {
2980 lit: syn::Lit::Str(lit_str),
2981 ..
2982 }) = &nv.value
2983 {
2984 model = Some(lit_str.value());
2985 }
2986 }
2987 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
2988 if let syn::Expr::Lit(syn::ExprLit {
2989 lit: syn::Lit::Str(lit_str),
2990 ..
2991 }) = &nv.value
2992 {
2993 inner = Some(lit_str.value());
2994 }
2995 }
2996 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
2997 if let syn::Expr::Lit(syn::ExprLit {
2998 lit: syn::Lit::Str(lit_str),
2999 ..
3000 }) = &nv.value
3001 {
3002 default_inner = Some(lit_str.value());
3003 }
3004 }
3005 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3006 if let syn::Expr::Lit(syn::ExprLit {
3007 lit: syn::Lit::Int(lit_int),
3008 ..
3009 }) = &nv.value
3010 {
3011 max_retries = Some(lit_int.base10_parse()?);
3012 }
3013 }
3014 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3015 if let syn::Expr::Lit(syn::ExprLit {
3016 lit: syn::Lit::Str(lit_str),
3017 ..
3018 }) = &nv.value
3019 {
3020 profile = Some(lit_str.value());
3021 }
3022 }
3023 _ => {}
3024 }
3025 }
3026
3027 Ok(AgentAttrs {
3028 expertise,
3029 output,
3030 backend,
3031 model,
3032 inner,
3033 default_inner,
3034 max_retries,
3035 profile,
3036 })
3037 }
3038}
3039
3040fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3042 for attr in attrs {
3043 if attr.path().is_ident("agent") {
3044 return attr.parse_args::<AgentAttrs>();
3045 }
3046 }
3047
3048 Ok(AgentAttrs {
3049 expertise: None,
3050 output: None,
3051 backend: None,
3052 model: None,
3053 inner: None,
3054 default_inner: None,
3055 max_retries: None,
3056 profile: None,
3057 })
3058}
3059
3060fn generate_backend_constructors(
3062 struct_name: &syn::Ident,
3063 backend: &str,
3064 _model: Option<&str>,
3065 _profile: Option<&str>,
3066 crate_path: &proc_macro2::TokenStream,
3067) -> proc_macro2::TokenStream {
3068 match backend {
3069 "claude" => {
3070 quote! {
3071 impl #struct_name {
3072 pub fn with_claude() -> Self {
3074 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3075 }
3076
3077 pub fn with_claude_model(model: &str) -> Self {
3079 Self::new(
3080 #crate_path::agent::impls::ClaudeCodeAgent::new()
3081 .with_model_str(model)
3082 )
3083 }
3084 }
3085 }
3086 }
3087 "gemini" => {
3088 quote! {
3089 impl #struct_name {
3090 pub fn with_gemini() -> Self {
3092 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3093 }
3094
3095 pub fn with_gemini_model(model: &str) -> Self {
3097 Self::new(
3098 #crate_path::agent::impls::GeminiAgent::new()
3099 .with_model_str(model)
3100 )
3101 }
3102 }
3103 }
3104 }
3105 _ => quote! {},
3106 }
3107}
3108
3109fn generate_default_impl(
3111 struct_name: &syn::Ident,
3112 backend: &str,
3113 model: Option<&str>,
3114 profile: Option<&str>,
3115 crate_path: &proc_macro2::TokenStream,
3116) -> proc_macro2::TokenStream {
3117 let profile_expr = if let Some(profile_str) = profile {
3119 match profile_str.to_lowercase().as_str() {
3120 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3121 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3122 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3123 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3125 } else {
3126 quote! { #crate_path::agent::ExecutionProfile::default() }
3127 };
3128
3129 let agent_init = match backend {
3130 "gemini" => {
3131 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3132
3133 if let Some(model_str) = model {
3134 builder = quote! { #builder.with_model_str(#model_str) };
3135 }
3136
3137 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3138 builder
3139 }
3140 _ => {
3141 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3143
3144 if let Some(model_str) = model {
3145 builder = quote! { #builder.with_model_str(#model_str) };
3146 }
3147
3148 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3149 builder
3150 }
3151 };
3152
3153 quote! {
3154 impl Default for #struct_name {
3155 fn default() -> Self {
3156 Self::new(#agent_init)
3157 }
3158 }
3159 }
3160}
3161
3162#[proc_macro_derive(Agent, attributes(agent))]
3171pub fn derive_agent(input: TokenStream) -> TokenStream {
3172 let input = parse_macro_input!(input as DeriveInput);
3173 let struct_name = &input.ident;
3174
3175 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3177 Ok(attrs) => attrs,
3178 Err(e) => return e.to_compile_error().into(),
3179 };
3180
3181 let expertise = agent_attrs
3182 .expertise
3183 .unwrap_or_else(|| String::from("general AI assistant"));
3184 let output_type = agent_attrs
3185 .output
3186 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3187 let backend = agent_attrs
3188 .backend
3189 .unwrap_or_else(|| String::from("claude"));
3190 let model = agent_attrs.model;
3191 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3196 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3197 let crate_path = match found_crate {
3198 FoundCrate::Itself => {
3199 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3201 quote!(::#ident)
3202 }
3203 FoundCrate::Name(name) => {
3204 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3205 quote!(::#ident)
3206 }
3207 };
3208
3209 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3210
3211 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3213 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3214
3215 let enhanced_expertise = if is_string_output {
3217 quote! { #expertise }
3219 } else {
3220 let type_name = quote!(#output_type).to_string();
3222 quote! {
3223 {
3224 use std::sync::OnceLock;
3225 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3226
3227 EXPERTISE_CACHE.get_or_init(|| {
3228 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3230
3231 if schema.is_empty() {
3232 format!(
3234 concat!(
3235 #expertise,
3236 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3237 "Do not include any text outside the JSON object."
3238 ),
3239 #type_name
3240 )
3241 } else {
3242 format!(
3244 concat!(
3245 #expertise,
3246 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3247 ),
3248 schema
3249 )
3250 }
3251 }).as_str()
3252 }
3253 }
3254 };
3255
3256 let agent_init = match backend.as_str() {
3258 "gemini" => {
3259 if let Some(model_str) = model {
3260 quote! {
3261 use #crate_path::agent::impls::GeminiAgent;
3262 let agent = GeminiAgent::new().with_model_str(#model_str);
3263 }
3264 } else {
3265 quote! {
3266 use #crate_path::agent::impls::GeminiAgent;
3267 let agent = GeminiAgent::new();
3268 }
3269 }
3270 }
3271 "claude" => {
3272 if let Some(model_str) = model {
3273 quote! {
3274 use #crate_path::agent::impls::ClaudeCodeAgent;
3275 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3276 }
3277 } else {
3278 quote! {
3279 use #crate_path::agent::impls::ClaudeCodeAgent;
3280 let agent = ClaudeCodeAgent::new();
3281 }
3282 }
3283 }
3284 _ => {
3285 if let Some(model_str) = model {
3287 quote! {
3288 use #crate_path::agent::impls::ClaudeCodeAgent;
3289 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3290 }
3291 } else {
3292 quote! {
3293 use #crate_path::agent::impls::ClaudeCodeAgent;
3294 let agent = ClaudeCodeAgent::new();
3295 }
3296 }
3297 }
3298 };
3299
3300 let expanded = quote! {
3301 #[async_trait::async_trait]
3302 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3303 type Output = #output_type;
3304
3305 fn expertise(&self) -> &str {
3306 #enhanced_expertise
3307 }
3308
3309 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3310 #agent_init
3312
3313 let max_retries: u32 = #max_retries;
3315 let mut attempts = 0u32;
3316
3317 loop {
3318 attempts += 1;
3319
3320 let result = async {
3322 let response = agent.execute(intent.clone()).await?;
3323
3324 let json_str = #crate_path::extract_json(&response)
3326 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3327
3328 serde_json::from_str::<Self::Output>(&json_str)
3330 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3331 }.await;
3332
3333 match result {
3334 Ok(output) => return Ok(output),
3335 Err(e) if e.is_retryable() && attempts < max_retries => {
3336 log::warn!(
3338 "Agent execution failed (attempt {}/{}): {}. Retrying...",
3339 attempts,
3340 max_retries,
3341 e
3342 );
3343
3344 let delay_ms = 100 * attempts as u64;
3346 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
3347
3348 continue;
3350 }
3351 Err(e) => {
3352 if attempts > 1 {
3353 log::error!(
3354 "Agent execution failed after {} attempts: {}",
3355 attempts,
3356 e
3357 );
3358 }
3359 return Err(e);
3360 }
3361 }
3362 }
3363 }
3364
3365 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3366 #agent_init
3368 agent.is_available().await
3369 }
3370 }
3371 };
3372
3373 TokenStream::from(expanded)
3374}
3375
3376#[proc_macro_attribute]
3391pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3392 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3394 Ok(attrs) => attrs,
3395 Err(e) => return e.to_compile_error().into(),
3396 };
3397
3398 let input = parse_macro_input!(item as DeriveInput);
3400 let struct_name = &input.ident;
3401 let vis = &input.vis;
3402
3403 let expertise = agent_attrs
3404 .expertise
3405 .unwrap_or_else(|| String::from("general AI assistant"));
3406 let output_type = agent_attrs
3407 .output
3408 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3409 let backend = agent_attrs
3410 .backend
3411 .unwrap_or_else(|| String::from("claude"));
3412 let model = agent_attrs.model;
3413 let profile = agent_attrs.profile;
3414
3415 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3417 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3418
3419 let found_crate =
3421 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3422 let crate_path = match found_crate {
3423 FoundCrate::Itself => {
3424 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3425 quote!(::#ident)
3426 }
3427 FoundCrate::Name(name) => {
3428 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3429 quote!(::#ident)
3430 }
3431 };
3432
3433 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3435 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3436
3437 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3439 let type_path: syn::Type =
3441 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3442 quote! { #type_path }
3443 } else {
3444 match backend.as_str() {
3446 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3447 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3448 }
3449 };
3450
3451 let struct_def = quote! {
3453 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3454 inner: #inner_generic_ident,
3455 }
3456 };
3457
3458 let constructors = quote! {
3460 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3461 pub fn new(inner: #inner_generic_ident) -> Self {
3463 Self { inner }
3464 }
3465 }
3466 };
3467
3468 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3470 let default_impl = quote! {
3472 impl Default for #struct_name {
3473 fn default() -> Self {
3474 Self {
3475 inner: <#default_agent_type as Default>::default(),
3476 }
3477 }
3478 }
3479 };
3480 (quote! {}, default_impl)
3481 } else {
3482 let backend_constructors = generate_backend_constructors(
3484 struct_name,
3485 &backend,
3486 model.as_deref(),
3487 profile.as_deref(),
3488 &crate_path,
3489 );
3490 let default_impl = generate_default_impl(
3491 struct_name,
3492 &backend,
3493 model.as_deref(),
3494 profile.as_deref(),
3495 &crate_path,
3496 );
3497 (backend_constructors, default_impl)
3498 };
3499
3500 let enhanced_expertise = if is_string_output {
3502 quote! { #expertise }
3504 } else {
3505 let type_name = quote!(#output_type).to_string();
3507 quote! {
3508 {
3509 use std::sync::OnceLock;
3510 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3511
3512 EXPERTISE_CACHE.get_or_init(|| {
3513 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3515
3516 if schema.is_empty() {
3517 format!(
3519 concat!(
3520 #expertise,
3521 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3522 "Do not include any text outside the JSON object."
3523 ),
3524 #type_name
3525 )
3526 } else {
3527 format!(
3529 concat!(
3530 #expertise,
3531 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3532 ),
3533 schema
3534 )
3535 }
3536 }).as_str()
3537 }
3538 }
3539 };
3540
3541 let agent_impl = quote! {
3543 #[async_trait::async_trait]
3544 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3545 where
3546 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3547 {
3548 type Output = #output_type;
3549
3550 fn expertise(&self) -> &str {
3551 #enhanced_expertise
3552 }
3553
3554 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3555 let enhanced_payload = intent.prepend_text(self.expertise());
3557
3558 let response = self.inner.execute(enhanced_payload).await?;
3560
3561 let json_str = #crate_path::extract_json(&response)
3563 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3564
3565 serde_json::from_str(&json_str)
3567 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3568 }
3569
3570 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3571 self.inner.is_available().await
3572 }
3573 }
3574 };
3575
3576 let expanded = quote! {
3577 #struct_def
3578 #constructors
3579 #backend_constructors
3580 #default_impl
3581 #agent_impl
3582 };
3583
3584 TokenStream::from(expanded)
3585}
3586
3587#[proc_macro_derive(TypeMarker)]
3609pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3610 let input = parse_macro_input!(input as DeriveInput);
3611 let struct_name = &input.ident;
3612 let type_name_str = struct_name.to_string();
3613
3614 let found_crate =
3616 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3617 let crate_path = match found_crate {
3618 FoundCrate::Itself => {
3619 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3620 quote!(::#ident)
3621 }
3622 FoundCrate::Name(name) => {
3623 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3624 quote!(::#ident)
3625 }
3626 };
3627
3628 let expanded = quote! {
3629 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3630 const TYPE_NAME: &'static str = #type_name_str;
3631 }
3632 };
3633
3634 TokenStream::from(expanded)
3635}