1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144
145 for (i, field) in fields.iter().enumerate() {
147 let field_name = field.ident.as_ref().unwrap();
148 let field_name_str = field_name.to_string();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if field_name_str == "__type" {
155 continue;
156 }
157
158 if attrs.skip {
160 continue;
161 }
162
163 let field_docs = extract_doc_comments(&field.attrs);
165
166 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
168
169 let remaining_fields = fields
171 .iter()
172 .skip(i + 1)
173 .filter(|f| {
174 let attrs = parse_field_prompt_attrs(&f.attrs);
175 !attrs.skip
176 })
177 .count();
178 let comma = if remaining_fields > 0 { "," } else { "" };
179
180 if is_vec {
181 let comment = if !field_docs.is_empty() {
183 format!(", // {}", field_docs)
184 } else {
185 String::new()
186 };
187
188 field_schema_parts.push(quote! {
189 {
190 let inner_schema = <#inner_type as #crate_path::prompt::ToPrompt>::prompt_schema();
191 if inner_schema.is_empty() {
192 format!(" \"{}\": \"{}[]\"{}{}", #field_name_str, stringify!(#inner_type).to_lowercase(), #comment, #comma)
194 } else {
195 let inner_lines: Vec<&str> = inner_schema.lines()
197 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
198 .take_while(|line| line.trim() != "}")
199 .collect();
200 let inner_content = inner_lines.join("\n");
201 format!(" \"{}\": [\n {{\n{}\n }}\n ]{}{}",
202 #field_name_str,
203 inner_content.lines()
204 .map(|line| format!(" {}", line))
205 .collect::<Vec<_>>()
206 .join("\n"),
207 #comment,
208 #comma
209 )
210 }
211 }
212 });
213 } else {
214 let field_type = &field.ty;
216 let is_primitive = is_primitive_type(field_type);
217
218 if !is_primitive {
219 let comment = if !field_docs.is_empty() {
221 format!(", // {}", field_docs)
222 } else {
223 String::new()
224 };
225
226 field_schema_parts.push(quote! {
227 {
228 let nested_schema = <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema();
229 if nested_schema.is_empty() {
230 let type_str = stringify!(#field_type).to_lowercase();
232 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
233 } else {
234 let nested_lines: Vec<&str> = nested_schema.lines()
236 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
237 .take_while(|line| line.trim() != "}")
238 .collect();
239
240 if nested_lines.is_empty() {
241 let type_str = stringify!(#field_type).to_lowercase();
243 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
244 } else {
245 let indented_content = nested_lines.iter()
247 .map(|line| format!(" {}", line))
248 .collect::<Vec<_>>()
249 .join("\n");
250 format!(" \"{}\": {{\n{}\n }}{}{}", #field_name_str, indented_content, #comment, #comma)
251 }
252 }
253 }
254 });
255 } else {
256 let type_str = format_type_for_schema(&field.ty);
258 let comment = if !field_docs.is_empty() {
259 format!(", // {}", field_docs)
260 } else {
261 String::new()
262 };
263
264 field_schema_parts.push(quote! {
265 format!(" \"{}\": \"{}\"{}{}", #field_name_str, #type_str, #comment, #comma)
266 });
267 }
268 }
269 }
270
271 let header = if !struct_docs.is_empty() {
273 format!("### Schema for `{}`\n{}", struct_name, struct_docs)
274 } else {
275 format!("### Schema for `{}`", struct_name)
276 };
277
278 quote! {
279 {
280 let mut lines = vec![#header.to_string(), "{".to_string()];
281 #(lines.push(#field_schema_parts);)*
282 lines.push("}".to_string());
283 vec![#crate_path::prompt::PromptPart::Text(lines.join("\n"))]
284 }
285 }
286}
287
288fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
290 if let syn::Type::Path(type_path) = ty
291 && let Some(last_segment) = type_path.path.segments.last()
292 && last_segment.ident == "Vec"
293 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
294 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
295 {
296 return (true, Some(inner_type));
297 }
298 (false, None)
299}
300
301fn is_primitive_type(ty: &syn::Type) -> bool {
303 if let syn::Type::Path(type_path) = ty
304 && let Some(last_segment) = type_path.path.segments.last()
305 {
306 let type_name = last_segment.ident.to_string();
307 matches!(
308 type_name.as_str(),
309 "String"
310 | "str"
311 | "i8"
312 | "i16"
313 | "i32"
314 | "i64"
315 | "i128"
316 | "isize"
317 | "u8"
318 | "u16"
319 | "u32"
320 | "u64"
321 | "u128"
322 | "usize"
323 | "f32"
324 | "f64"
325 | "bool"
326 | "Vec"
327 | "Option"
328 | "HashMap"
329 | "BTreeMap"
330 | "HashSet"
331 | "BTreeSet"
332 )
333 } else {
334 true
336 }
337}
338
339fn format_type_for_schema(ty: &syn::Type) -> String {
341 match ty {
343 syn::Type::Path(type_path) => {
344 let path = &type_path.path;
345 if let Some(last_segment) = path.segments.last() {
346 let type_name = last_segment.ident.to_string();
347
348 if type_name == "Option"
350 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
351 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
352 {
353 return format!("{} | null", format_type_for_schema(inner_type));
354 }
355
356 match type_name.as_str() {
358 "String" | "str" => "string".to_string(),
359 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
360 | "u64" | "u128" | "usize" => "number".to_string(),
361 "f32" | "f64" => "number".to_string(),
362 "bool" => "boolean".to_string(),
363 "Vec" => {
364 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
365 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
366 {
367 return format!("{}[]", format_type_for_schema(inner_type));
368 }
369 "array".to_string()
370 }
371 _ => type_name.to_lowercase(),
372 }
373 } else {
374 "unknown".to_string()
375 }
376 }
377 _ => "unknown".to_string(),
378 }
379}
380
/// Outcome of reading an enum variant's `#[prompt(...)]` attribute via
/// `parse_prompt_attribute`.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the enum's `prompt_schema()` listing.
    Skip,
    /// `#[prompt("text")]`: the given text describes the variant instead of its doc comment.
    Description(String),
    /// No recognized `#[prompt(...)]` attribute was found.
    None,
}
387
388fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
390 for attr in attrs {
391 if attr.path().is_ident("prompt") {
392 if let Ok(meta_list) = attr.meta.require_list() {
394 let tokens = &meta_list.tokens;
395 let tokens_str = tokens.to_string();
396 if tokens_str == "skip" {
397 return PromptAttribute::Skip;
398 }
399 }
400
401 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
403 return PromptAttribute::Description(lit_str.value());
404 }
405 }
406 }
407 PromptAttribute::None
408}
409
/// Field-level options parsed from `#[prompt(...)]` by `parse_field_prompt_attrs`.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of generated prompts, schemas and examples.
    skip: bool,
    /// `rename = "..."`: label used instead of the field's doc comment or name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&field` to render its value.
    format_with: Option<String>,
    /// `image`: the field contributes its own `to_prompt_parts()` instead of a text line.
    image: bool,
    /// `example = "..."`: string shown for this field in generated examples.
    example: Option<String>,
}
419
420fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
422 let mut result = FieldPromptAttrs::default();
423
424 for attr in attrs {
425 if attr.path().is_ident("prompt") {
426 if let Ok(meta_list) = attr.meta.require_list() {
428 if let Ok(metas) =
430 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
431 {
432 for meta in metas {
433 match meta {
434 Meta::Path(path) if path.is_ident("skip") => {
435 result.skip = true;
436 }
437 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
438 if let syn::Expr::Lit(syn::ExprLit {
439 lit: syn::Lit::Str(lit_str),
440 ..
441 }) = nv.value
442 {
443 result.rename = Some(lit_str.value());
444 }
445 }
446 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
447 if let syn::Expr::Lit(syn::ExprLit {
448 lit: syn::Lit::Str(lit_str),
449 ..
450 }) = nv.value
451 {
452 result.format_with = Some(lit_str.value());
453 }
454 }
455 Meta::Path(path) if path.is_ident("image") => {
456 result.image = true;
457 }
458 Meta::NameValue(nv) if nv.path.is_ident("example") => {
459 if let syn::Expr::Lit(syn::ExprLit {
460 lit: syn::Lit::Str(lit_str),
461 ..
462 }) = nv.value
463 {
464 result.example = Some(lit_str.value());
465 }
466 }
467 _ => {}
468 }
469 }
470 } else if meta_list.tokens.to_string() == "skip" {
471 result.skip = true;
473 } else if meta_list.tokens.to_string() == "image" {
474 result.image = true;
476 }
477 }
478 }
479 }
480
481 result
482}
483
484#[proc_macro_derive(ToPrompt, attributes(prompt))]
527pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
528 let input = parse_macro_input!(input as DeriveInput);
529
530 let found_crate =
531 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
532 let crate_path = match found_crate {
533 FoundCrate::Itself => {
534 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
536 quote!(::#ident)
537 }
538 FoundCrate::Name(name) => {
539 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
540 quote!(::#ident)
541 }
542 };
543
544 match &input.data {
546 Data::Enum(data_enum) => {
547 let enum_name = &input.ident;
549 let enum_docs = extract_doc_comments(&input.attrs);
550
551 let mut prompt_lines = Vec::new();
552
553 if !enum_docs.is_empty() {
555 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
556 } else {
557 prompt_lines.push(format!("{}:", enum_name));
558 }
559 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
561
562 for variant in &data_enum.variants {
564 let variant_name = &variant.ident;
565
566 match parse_prompt_attribute(&variant.attrs) {
568 PromptAttribute::Skip => {
569 continue;
571 }
572 PromptAttribute::Description(desc) => {
573 prompt_lines.push(format!("- {}: {}", variant_name, desc));
575 }
576 PromptAttribute::None => {
577 let variant_docs = extract_doc_comments(&variant.attrs);
579 if !variant_docs.is_empty() {
580 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
581 } else {
582 prompt_lines.push(format!("- {}", variant_name));
583 }
584 }
585 }
586 }
587
588 let prompt_string = prompt_lines.join("\n");
589 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
590
591 let mut match_arms = Vec::new();
593 for variant in &data_enum.variants {
594 let variant_name = &variant.ident;
595
596 match parse_prompt_attribute(&variant.attrs) {
598 PromptAttribute::Skip => {
599 match_arms.push(quote! {
601 Self::#variant_name => stringify!(#variant_name).to_string()
602 });
603 }
604 PromptAttribute::Description(desc) => {
605 match_arms.push(quote! {
607 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
608 });
609 }
610 PromptAttribute::None => {
611 let variant_docs = extract_doc_comments(&variant.attrs);
613 if !variant_docs.is_empty() {
614 match_arms.push(quote! {
615 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
616 });
617 } else {
618 match_arms.push(quote! {
619 Self::#variant_name => stringify!(#variant_name).to_string()
620 });
621 }
622 }
623 }
624 }
625
626 let to_prompt_impl = if match_arms.is_empty() {
627 quote! {
629 fn to_prompt(&self) -> String {
630 match *self {}
631 }
632 }
633 } else {
634 quote! {
635 fn to_prompt(&self) -> String {
636 match self {
637 #(#match_arms),*
638 }
639 }
640 }
641 };
642
643 let expanded = quote! {
644 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
645 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
646 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
647 }
648
649 #to_prompt_impl
650
651 fn prompt_schema() -> String {
652 #prompt_string.to_string()
653 }
654 }
655 };
656
657 TokenStream::from(expanded)
658 }
659 Data::Struct(data_struct) => {
660 let mut template_attr = None;
662 let mut template_file_attr = None;
663 let mut mode_attr = None;
664 let mut validate_attr = false;
665 let mut type_marker_attr = false;
666
667 for attr in &input.attrs {
668 if attr.path().is_ident("prompt") {
669 if let Ok(metas) =
671 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
672 {
673 for meta in metas {
674 match meta {
675 Meta::NameValue(nv) if nv.path.is_ident("template") => {
676 if let syn::Expr::Lit(expr_lit) = nv.value
677 && let syn::Lit::Str(lit_str) = expr_lit.lit
678 {
679 template_attr = Some(lit_str.value());
680 }
681 }
682 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
683 if let syn::Expr::Lit(expr_lit) = nv.value
684 && let syn::Lit::Str(lit_str) = expr_lit.lit
685 {
686 template_file_attr = Some(lit_str.value());
687 }
688 }
689 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
690 if let syn::Expr::Lit(expr_lit) = nv.value
691 && let syn::Lit::Str(lit_str) = expr_lit.lit
692 {
693 mode_attr = Some(lit_str.value());
694 }
695 }
696 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
697 if let syn::Expr::Lit(expr_lit) = nv.value
698 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
699 {
700 validate_attr = lit_bool.value();
701 }
702 }
703 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
704 if let syn::Expr::Lit(expr_lit) = nv.value
705 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
706 {
707 type_marker_attr = lit_bool.value();
708 }
709 }
710 Meta::Path(path) if path.is_ident("type_marker") => {
711 type_marker_attr = true;
713 }
714 _ => {}
715 }
716 }
717 }
718 }
719 }
720
721 if template_attr.is_some() && template_file_attr.is_some() {
723 return syn::Error::new(
724 input.ident.span(),
725 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
726 ).to_compile_error().into();
727 }
728
729 let template_str = if let Some(file_path) = template_file_attr {
731 let mut full_path = None;
735
736 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
738 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
740
741 if !is_trybuild {
742 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
744 if candidate.exists() {
745 full_path = Some(candidate);
746 }
747 } else {
748 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
754 let workspace_root = &manifest_dir[..target_pos];
755 let original_macros_dir = std::path::Path::new(workspace_root)
757 .join("crates")
758 .join("llm-toolkit-macros");
759
760 let candidate = original_macros_dir.join(&file_path);
761 if candidate.exists() {
762 full_path = Some(candidate);
763 }
764 }
765 }
766 }
767
768 if full_path.is_none() {
770 let candidate = std::path::Path::new(&file_path).to_path_buf();
771 if candidate.exists() {
772 full_path = Some(candidate);
773 }
774 }
775
776 if full_path.is_none()
779 && let Ok(current_dir) = std::env::current_dir()
780 {
781 let mut search_dir = current_dir.as_path();
782 for _ in 0..10 {
784 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
786 if macros_dir.exists() {
787 let candidate = macros_dir.join(&file_path);
788 if candidate.exists() {
789 full_path = Some(candidate);
790 break;
791 }
792 }
793 let candidate = search_dir.join(&file_path);
795 if candidate.exists() {
796 full_path = Some(candidate);
797 break;
798 }
799 if let Some(parent) = search_dir.parent() {
800 search_dir = parent;
801 } else {
802 break;
803 }
804 }
805 }
806
807 if full_path.is_none() {
809 let mut error_msg = format!(
811 "Template file '{}' not found at compile time.\n\nSearched in:",
812 file_path
813 );
814
815 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
816 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
817 error_msg.push_str(&format!("\n - {}", candidate.display()));
818 }
819
820 if let Ok(current_dir) = std::env::current_dir() {
821 let candidate = current_dir.join(&file_path);
822 error_msg.push_str(&format!("\n - {}", candidate.display()));
823 }
824
825 error_msg.push_str("\n\nPlease ensure:");
826 error_msg.push_str("\n 1. The template file exists");
827 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
828 error_msg.push_str("\n 3. There are no typos in the path");
829
830 return syn::Error::new(input.ident.span(), error_msg)
831 .to_compile_error()
832 .into();
833 }
834
835 let final_path = full_path.unwrap();
836
837 match std::fs::read_to_string(&final_path) {
839 Ok(content) => Some(content),
840 Err(e) => {
841 return syn::Error::new(
842 input.ident.span(),
843 format!(
844 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
845 file_path,
846 e,
847 final_path.display()
848 ),
849 )
850 .to_compile_error()
851 .into();
852 }
853 }
854 } else {
855 template_attr
856 };
857
858 if validate_attr && let Some(template) = &template_str {
860 let mut env = minijinja::Environment::new();
862 if let Err(e) = env.add_template("validation", template) {
863 let warning_msg =
865 format!("Template validation warning: Invalid Jinja syntax - {}", e);
866 let warning_ident = syn::Ident::new(
867 "TEMPLATE_VALIDATION_WARNING",
868 proc_macro2::Span::call_site(),
869 );
870 let _warning_tokens = quote! {
871 #[deprecated(note = #warning_msg)]
872 const #warning_ident: () = ();
873 let _ = #warning_ident;
874 };
875 eprintln!("cargo:warning={}", warning_msg);
877 }
878
879 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
881 &fields.named
882 } else {
883 panic!("Template validation is only supported for structs with named fields.");
884 };
885
886 let field_names: std::collections::HashSet<String> = fields
887 .iter()
888 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
889 .collect();
890
891 let placeholders = parse_template_placeholders_with_mode(template);
893
894 for (placeholder_name, _mode) in &placeholders {
895 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
896 let warning_msg = format!(
897 "Template validation warning: Variable '{}' used in template but not found in struct fields",
898 placeholder_name
899 );
900 eprintln!("cargo:warning={}", warning_msg);
901 }
902 }
903 }
904
905 let name = input.ident;
906 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
907
908 let struct_docs = extract_doc_comments(&input.attrs);
910
911 let is_mode_based =
913 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
914
915 let expanded = if is_mode_based || mode_attr.is_some() {
916 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
918 &fields.named
919 } else {
920 panic!(
921 "Mode-based prompt generation is only supported for structs with named fields."
922 );
923 };
924
925 let struct_name_str = name.to_string();
926
927 let has_default = input.attrs.iter().any(|attr| {
929 if attr.path().is_ident("derive")
930 && let Ok(meta_list) = attr.meta.require_list()
931 {
932 let tokens_str = meta_list.tokens.to_string();
933 tokens_str.contains("Default")
934 } else {
935 false
936 }
937 });
938
939 let schema_parts = generate_schema_only_parts(
950 &struct_name_str,
951 &struct_docs,
952 fields,
953 &crate_path,
954 type_marker_attr,
955 );
956
957 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
959
960 quote! {
961 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
962 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
963 match mode {
964 "schema_only" => #schema_parts,
965 "example_only" => #example_parts,
966 "full" | _ => {
967 let mut parts = Vec::new();
969
970 let schema_parts = #schema_parts;
972 parts.extend(schema_parts);
973
974 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
976 parts.push(#crate_path::prompt::PromptPart::Text(
977 format!("Here is an example of a valid `{}` object:", #struct_name_str)
978 ));
979
980 let example_parts = #example_parts;
982 parts.extend(example_parts);
983
984 parts
985 }
986 }
987 }
988
989 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
990 self.to_prompt_parts_with_mode("full")
991 }
992
993 fn to_prompt(&self) -> String {
994 self.to_prompt_parts()
995 .into_iter()
996 .filter_map(|part| match part {
997 #crate_path::prompt::PromptPart::Text(text) => Some(text),
998 _ => None,
999 })
1000 .collect::<Vec<_>>()
1001 .join("\n")
1002 }
1003
1004 fn prompt_schema() -> String {
1005 use std::sync::OnceLock;
1006 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1007
1008 SCHEMA_CACHE.get_or_init(|| {
1009 let schema_parts = #schema_parts;
1010 schema_parts
1011 .into_iter()
1012 .filter_map(|part| match part {
1013 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1014 _ => None,
1015 })
1016 .collect::<Vec<_>>()
1017 .join("\n")
1018 }).clone()
1019 }
1020 }
1021 }
1022 } else if let Some(template) = template_str {
1023 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1026 &fields.named
1027 } else {
1028 panic!(
1029 "Template prompt generation is only supported for structs with named fields."
1030 );
1031 };
1032
1033 let placeholders = parse_template_placeholders_with_mode(&template);
1035 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1037 mode.is_some()
1038 && fields
1039 .iter()
1040 .any(|f| f.ident.as_ref().unwrap() == field_name)
1041 });
1042
1043 let mut image_field_parts = Vec::new();
1044 for f in fields.iter() {
1045 let field_name = f.ident.as_ref().unwrap();
1046 let attrs = parse_field_prompt_attrs(&f.attrs);
1047
1048 if attrs.image {
1049 image_field_parts.push(quote! {
1051 parts.extend(self.#field_name.to_prompt_parts());
1052 });
1053 }
1054 }
1055
1056 if has_mode_syntax {
1058 let mut context_fields = Vec::new();
1060 let mut modified_template = template.clone();
1061
1062 for (field_name, mode_opt) in &placeholders {
1064 if let Some(mode) = mode_opt {
1065 let unique_key = format!("{}__{}", field_name, mode);
1067
1068 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1070 let replacement = format!("{{{{ {} }}}}", unique_key);
1071 modified_template = modified_template.replace(&pattern, &replacement);
1072
1073 let field_ident =
1075 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1076
1077 context_fields.push(quote! {
1079 context.insert(
1080 #unique_key.to_string(),
1081 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1082 );
1083 });
1084 }
1085 }
1086
1087 for field in fields.iter() {
1089 let field_name = field.ident.as_ref().unwrap();
1090 let field_name_str = field_name.to_string();
1091
1092 let has_mode_entry = placeholders
1094 .iter()
1095 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1096
1097 if !has_mode_entry {
1098 let is_primitive = match &field.ty {
1101 syn::Type::Path(type_path) => {
1102 if let Some(segment) = type_path.path.segments.last() {
1103 let type_name = segment.ident.to_string();
1104 matches!(
1105 type_name.as_str(),
1106 "String"
1107 | "str"
1108 | "i8"
1109 | "i16"
1110 | "i32"
1111 | "i64"
1112 | "i128"
1113 | "isize"
1114 | "u8"
1115 | "u16"
1116 | "u32"
1117 | "u64"
1118 | "u128"
1119 | "usize"
1120 | "f32"
1121 | "f64"
1122 | "bool"
1123 | "char"
1124 )
1125 } else {
1126 false
1127 }
1128 }
1129 _ => false,
1130 };
1131
1132 if is_primitive {
1133 context_fields.push(quote! {
1134 context.insert(
1135 #field_name_str.to_string(),
1136 minijinja::Value::from_serialize(&self.#field_name)
1137 );
1138 });
1139 } else {
1140 context_fields.push(quote! {
1142 context.insert(
1143 #field_name_str.to_string(),
1144 minijinja::Value::from(self.#field_name.to_prompt())
1145 );
1146 });
1147 }
1148 }
1149 }
1150
1151 quote! {
1152 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1153 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1154 let mut parts = Vec::new();
1155
1156 #(#image_field_parts)*
1158
1159 let text = {
1161 let mut env = minijinja::Environment::new();
1162 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1163 panic!("Failed to parse template: {}", e)
1164 });
1165
1166 let tmpl = env.get_template("prompt").unwrap();
1167
1168 let mut context = std::collections::HashMap::new();
1169 #(#context_fields)*
1170
1171 tmpl.render(context).unwrap_or_else(|e| {
1172 format!("Failed to render prompt: {}", e)
1173 })
1174 };
1175
1176 if !text.is_empty() {
1177 parts.push(#crate_path::prompt::PromptPart::Text(text));
1178 }
1179
1180 parts
1181 }
1182
1183 fn to_prompt(&self) -> String {
1184 let mut env = minijinja::Environment::new();
1186 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1187 panic!("Failed to parse template: {}", e)
1188 });
1189
1190 let tmpl = env.get_template("prompt").unwrap();
1191
1192 let mut context = std::collections::HashMap::new();
1193 #(#context_fields)*
1194
1195 tmpl.render(context).unwrap_or_else(|e| {
1196 format!("Failed to render prompt: {}", e)
1197 })
1198 }
1199
1200 fn prompt_schema() -> String {
1201 String::new() }
1203 }
1204 }
1205 } else {
1206 quote! {
1208 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1209 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1210 let mut parts = Vec::new();
1211
1212 #(#image_field_parts)*
1214
1215 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1217 format!("Failed to render prompt: {}", e)
1218 });
1219 if !text.is_empty() {
1220 parts.push(#crate_path::prompt::PromptPart::Text(text));
1221 }
1222
1223 parts
1224 }
1225
1226 fn to_prompt(&self) -> String {
1227 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1228 format!("Failed to render prompt: {}", e)
1229 })
1230 }
1231
1232 fn prompt_schema() -> String {
1233 String::new() }
1235 }
1236 }
1237 }
1238 } else {
1239 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1242 &fields.named
1243 } else {
1244 panic!(
1245 "Default prompt generation is only supported for structs with named fields."
1246 );
1247 };
1248
1249 let mut text_field_parts = Vec::new();
1251 let mut image_field_parts = Vec::new();
1252
1253 for f in fields.iter() {
1254 let field_name = f.ident.as_ref().unwrap();
1255 let attrs = parse_field_prompt_attrs(&f.attrs);
1256
1257 if attrs.skip {
1259 continue;
1260 }
1261
1262 if attrs.image {
1263 image_field_parts.push(quote! {
1265 parts.extend(self.#field_name.to_prompt_parts());
1266 });
1267 } else {
1268 let key = if let Some(rename) = attrs.rename {
1274 rename
1275 } else {
1276 let doc_comment = extract_doc_comments(&f.attrs);
1277 if !doc_comment.is_empty() {
1278 doc_comment
1279 } else {
1280 field_name.to_string()
1281 }
1282 };
1283
1284 let value_expr = if let Some(format_with) = attrs.format_with {
1286 let func_path: syn::Path =
1288 syn::parse_str(&format_with).unwrap_or_else(|_| {
1289 panic!("Invalid function path: {}", format_with)
1290 });
1291 quote! { #func_path(&self.#field_name) }
1292 } else {
1293 quote! { self.#field_name.to_prompt() }
1294 };
1295
1296 text_field_parts.push(quote! {
1297 text_parts.push(format!("{}: {}", #key, #value_expr));
1298 });
1299 }
1300 }
1301
1302 quote! {
1304 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1305 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1306 let mut parts = Vec::new();
1307
1308 #(#image_field_parts)*
1310
1311 let mut text_parts = Vec::new();
1313 #(#text_field_parts)*
1314
1315 if !text_parts.is_empty() {
1316 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1317 }
1318
1319 parts
1320 }
1321
1322 fn to_prompt(&self) -> String {
1323 let mut text_parts = Vec::new();
1324 #(#text_field_parts)*
1325 text_parts.join("\n")
1326 }
1327
1328 fn prompt_schema() -> String {
1329 String::new() }
1331 }
1332 }
1333 };
1334
1335 TokenStream::from(expanded)
1336 }
1337 Data::Union(_) => {
1338 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1339 }
1340 }
1341}
1342
1343#[derive(Debug, Clone)]
1345struct TargetInfo {
1346 name: String,
1347 template: Option<String>,
1348 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1349}
1350
/// Per-target field options parsed from a field's `#[prompt_for(...)]` attribute by
/// `parse_prompt_for_attrs`.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: leave the field out for this target.
    skip: bool,
    /// `rename = "..."`: alternative label for the field.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: custom formatting function path.
    format_with: Option<String>,
    /// `image`: treat the field as image content.
    image: bool,
    /// Set when the attribute named a specific target via `name = "..."`.
    include_only: bool,
}
1360
1361fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1363 let mut configs = Vec::new();
1364
1365 for attr in attrs {
1366 if attr.path().is_ident("prompt_for")
1367 && let Ok(meta_list) = attr.meta.require_list()
1368 {
1369 if meta_list.tokens.to_string() == "skip" {
1371 let config = FieldTargetConfig {
1373 skip: true,
1374 ..Default::default()
1375 };
1376 configs.push(("*".to_string(), config));
1377 } else if let Ok(metas) =
1378 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1379 {
1380 let mut target_name = None;
1381 let mut config = FieldTargetConfig::default();
1382
1383 for meta in metas {
1384 match meta {
1385 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1386 if let syn::Expr::Lit(syn::ExprLit {
1387 lit: syn::Lit::Str(lit_str),
1388 ..
1389 }) = nv.value
1390 {
1391 target_name = Some(lit_str.value());
1392 }
1393 }
1394 Meta::Path(path) if path.is_ident("skip") => {
1395 config.skip = true;
1396 }
1397 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1398 if let syn::Expr::Lit(syn::ExprLit {
1399 lit: syn::Lit::Str(lit_str),
1400 ..
1401 }) = nv.value
1402 {
1403 config.rename = Some(lit_str.value());
1404 }
1405 }
1406 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1407 if let syn::Expr::Lit(syn::ExprLit {
1408 lit: syn::Lit::Str(lit_str),
1409 ..
1410 }) = nv.value
1411 {
1412 config.format_with = Some(lit_str.value());
1413 }
1414 }
1415 Meta::Path(path) if path.is_ident("image") => {
1416 config.image = true;
1417 }
1418 _ => {}
1419 }
1420 }
1421
1422 if let Some(name) = target_name {
1423 config.include_only = true;
1424 configs.push((name, config));
1425 }
1426 }
1427 }
1428 }
1429
1430 configs
1431}
1432
1433fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1435 let mut targets = Vec::new();
1436
1437 for attr in attrs {
1438 if attr.path().is_ident("prompt_for")
1439 && let Ok(meta_list) = attr.meta.require_list()
1440 && let Ok(metas) =
1441 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1442 {
1443 let mut target_name = None;
1444 let mut template = None;
1445
1446 for meta in metas {
1447 match meta {
1448 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1449 if let syn::Expr::Lit(syn::ExprLit {
1450 lit: syn::Lit::Str(lit_str),
1451 ..
1452 }) = nv.value
1453 {
1454 target_name = Some(lit_str.value());
1455 }
1456 }
1457 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1458 if let syn::Expr::Lit(syn::ExprLit {
1459 lit: syn::Lit::Str(lit_str),
1460 ..
1461 }) = nv.value
1462 {
1463 template = Some(lit_str.value());
1464 }
1465 }
1466 _ => {}
1467 }
1468 }
1469
1470 if let Some(name) = target_name {
1471 targets.push(TargetInfo {
1472 name,
1473 template,
1474 field_configs: std::collections::HashMap::new(),
1475 });
1476 }
1477 }
1478 }
1479
1480 targets
1481}
1482
1483#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1484pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1485 let input = parse_macro_input!(input as DeriveInput);
1486
1487 let found_crate =
1488 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1489 let crate_path = match found_crate {
1490 FoundCrate::Itself => {
1491 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1493 quote!(::#ident)
1494 }
1495 FoundCrate::Name(name) => {
1496 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1497 quote!(::#ident)
1498 }
1499 };
1500
1501 let data_struct = match &input.data {
1503 Data::Struct(data) => data,
1504 _ => {
1505 return syn::Error::new(
1506 input.ident.span(),
1507 "`#[derive(ToPromptSet)]` is only supported for structs",
1508 )
1509 .to_compile_error()
1510 .into();
1511 }
1512 };
1513
1514 let fields = match &data_struct.fields {
1515 syn::Fields::Named(fields) => &fields.named,
1516 _ => {
1517 return syn::Error::new(
1518 input.ident.span(),
1519 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1520 )
1521 .to_compile_error()
1522 .into();
1523 }
1524 };
1525
1526 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1528
1529 for field in fields.iter() {
1531 let field_name = field.ident.as_ref().unwrap().to_string();
1532 let field_configs = parse_prompt_for_attrs(&field.attrs);
1533
1534 for (target_name, config) in field_configs {
1535 if target_name == "*" {
1536 for target in &mut targets {
1538 target
1539 .field_configs
1540 .entry(field_name.clone())
1541 .or_insert_with(FieldTargetConfig::default)
1542 .skip = config.skip;
1543 }
1544 } else {
1545 let target_exists = targets.iter().any(|t| t.name == target_name);
1547 if !target_exists {
1548 targets.push(TargetInfo {
1550 name: target_name.clone(),
1551 template: None,
1552 field_configs: std::collections::HashMap::new(),
1553 });
1554 }
1555
1556 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1557
1558 target.field_configs.insert(field_name.clone(), config);
1559 }
1560 }
1561 }
1562
1563 let mut match_arms = Vec::new();
1565
1566 for target in &targets {
1567 let target_name = &target.name;
1568
1569 if let Some(template_str) = &target.template {
1570 let mut image_parts = Vec::new();
1572
1573 for field in fields.iter() {
1574 let field_name = field.ident.as_ref().unwrap();
1575 let field_name_str = field_name.to_string();
1576
1577 if let Some(config) = target.field_configs.get(&field_name_str)
1578 && config.image
1579 {
1580 image_parts.push(quote! {
1581 parts.extend(self.#field_name.to_prompt_parts());
1582 });
1583 }
1584 }
1585
1586 match_arms.push(quote! {
1587 #target_name => {
1588 let mut parts = Vec::new();
1589
1590 #(#image_parts)*
1591
1592 let text = #crate_path::prompt::render_prompt(#template_str, self)
1593 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1594 target: #target_name.to_string(),
1595 source: e,
1596 })?;
1597
1598 if !text.is_empty() {
1599 parts.push(#crate_path::prompt::PromptPart::Text(text));
1600 }
1601
1602 Ok(parts)
1603 }
1604 });
1605 } else {
1606 let mut text_field_parts = Vec::new();
1608 let mut image_field_parts = Vec::new();
1609
1610 for field in fields.iter() {
1611 let field_name = field.ident.as_ref().unwrap();
1612 let field_name_str = field_name.to_string();
1613
1614 let config = target.field_configs.get(&field_name_str);
1616
1617 if let Some(cfg) = config
1619 && cfg.skip
1620 {
1621 continue;
1622 }
1623
1624 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1628 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1629 .iter()
1630 .any(|(name, _)| name != "*");
1631
1632 if has_any_target_specific_config && !is_explicitly_for_this_target {
1633 continue;
1634 }
1635
1636 if let Some(cfg) = config {
1637 if cfg.image {
1638 image_field_parts.push(quote! {
1639 parts.extend(self.#field_name.to_prompt_parts());
1640 });
1641 } else {
1642 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1643
1644 let value_expr = if let Some(format_with) = &cfg.format_with {
1645 match syn::parse_str::<syn::Path>(format_with) {
1647 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1648 Err(_) => {
1649 let error_msg = format!(
1651 "Invalid function path in format_with: '{}'",
1652 format_with
1653 );
1654 quote! {
1655 compile_error!(#error_msg);
1656 String::new()
1657 }
1658 }
1659 }
1660 } else {
1661 quote! { self.#field_name.to_prompt() }
1662 };
1663
1664 text_field_parts.push(quote! {
1665 text_parts.push(format!("{}: {}", #key, #value_expr));
1666 });
1667 }
1668 } else {
1669 text_field_parts.push(quote! {
1671 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1672 });
1673 }
1674 }
1675
1676 match_arms.push(quote! {
1677 #target_name => {
1678 let mut parts = Vec::new();
1679
1680 #(#image_field_parts)*
1681
1682 let mut text_parts = Vec::new();
1683 #(#text_field_parts)*
1684
1685 if !text_parts.is_empty() {
1686 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1687 }
1688
1689 Ok(parts)
1690 }
1691 });
1692 }
1693 }
1694
1695 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1697
1698 match_arms.push(quote! {
1700 _ => {
1701 let available = vec![#(#target_names.to_string()),*];
1702 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1703 target: target.to_string(),
1704 available,
1705 })
1706 }
1707 });
1708
1709 let struct_name = &input.ident;
1710 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1711
1712 let expanded = quote! {
1713 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1714 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1715 match target {
1716 #(#match_arms)*
1717 }
1718 }
1719 }
1720 };
1721
1722 TokenStream::from(expanded)
1723}
1724
1725struct TypeList {
1727 types: Punctuated<syn::Type, Token![,]>,
1728}
1729
1730impl Parse for TypeList {
1731 fn parse(input: ParseStream) -> syn::Result<Self> {
1732 Ok(TypeList {
1733 types: Punctuated::parse_terminated(input)?,
1734 })
1735 }
1736}
1737
1738#[proc_macro]
1762pub fn examples_section(input: TokenStream) -> TokenStream {
1763 let input = parse_macro_input!(input as TypeList);
1764
1765 let found_crate =
1766 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1767 let _crate_path = match found_crate {
1768 FoundCrate::Itself => quote!(crate),
1769 FoundCrate::Name(name) => {
1770 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1771 quote!(::#ident)
1772 }
1773 };
1774
1775 let mut type_sections = Vec::new();
1777
1778 for ty in input.types.iter() {
1779 let type_name_str = quote!(#ty).to_string();
1781
1782 type_sections.push(quote! {
1784 {
1785 let type_name = #type_name_str;
1786 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1787 format!("---\n#### `{}`\n{}", type_name, json_example)
1788 }
1789 });
1790 }
1791
1792 let expanded = quote! {
1794 {
1795 let mut sections = Vec::new();
1796 sections.push("---".to_string());
1797 sections.push("### Examples".to_string());
1798 sections.push("".to_string());
1799 sections.push("Here are examples of the data structures you should use.".to_string());
1800 sections.push("".to_string());
1801
1802 #(sections.push(#type_sections);)*
1803
1804 sections.push("---".to_string());
1805
1806 sections.join("\n")
1807 }
1808 };
1809
1810 TokenStream::from(expanded)
1811}
1812
1813fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1815 for attr in attrs {
1816 if attr.path().is_ident("prompt_for")
1817 && let Ok(meta_list) = attr.meta.require_list()
1818 && let Ok(metas) =
1819 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1820 {
1821 let mut target_type = None;
1822 let mut template = None;
1823
1824 for meta in metas {
1825 match meta {
1826 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1827 if let syn::Expr::Lit(syn::ExprLit {
1828 lit: syn::Lit::Str(lit_str),
1829 ..
1830 }) = nv.value
1831 {
1832 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1834 }
1835 }
1836 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1837 if let syn::Expr::Lit(syn::ExprLit {
1838 lit: syn::Lit::Str(lit_str),
1839 ..
1840 }) = nv.value
1841 {
1842 template = Some(lit_str.value());
1843 }
1844 }
1845 _ => {}
1846 }
1847 }
1848
1849 if let (Some(target), Some(tmpl)) = (target_type, template) {
1850 return (target, tmpl);
1851 }
1852 }
1853 }
1854
1855 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1856}
1857
1858#[proc_macro_attribute]
1892pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1893 let input = parse_macro_input!(item as DeriveInput);
1894
1895 let found_crate =
1896 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1897 let crate_path = match found_crate {
1898 FoundCrate::Itself => {
1899 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1901 quote!(::#ident)
1902 }
1903 FoundCrate::Name(name) => {
1904 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1905 quote!(::#ident)
1906 }
1907 };
1908
1909 let enum_data = match &input.data {
1911 Data::Enum(data) => data,
1912 _ => {
1913 return syn::Error::new(
1914 input.ident.span(),
1915 "`#[define_intent]` can only be applied to enums",
1916 )
1917 .to_compile_error()
1918 .into();
1919 }
1920 };
1921
1922 let mut prompt_template = None;
1924 let mut extractor_tag = None;
1925 let mut mode = None;
1926
1927 for attr in &input.attrs {
1928 if attr.path().is_ident("intent")
1929 && let Ok(metas) =
1930 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1931 {
1932 for meta in metas {
1933 match meta {
1934 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1935 if let syn::Expr::Lit(syn::ExprLit {
1936 lit: syn::Lit::Str(lit_str),
1937 ..
1938 }) = nv.value
1939 {
1940 prompt_template = Some(lit_str.value());
1941 }
1942 }
1943 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1944 if let syn::Expr::Lit(syn::ExprLit {
1945 lit: syn::Lit::Str(lit_str),
1946 ..
1947 }) = nv.value
1948 {
1949 extractor_tag = Some(lit_str.value());
1950 }
1951 }
1952 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1953 if let syn::Expr::Lit(syn::ExprLit {
1954 lit: syn::Lit::Str(lit_str),
1955 ..
1956 }) = nv.value
1957 {
1958 mode = Some(lit_str.value());
1959 }
1960 }
1961 _ => {}
1962 }
1963 }
1964 }
1965 }
1966
1967 let mode = mode.unwrap_or_else(|| "single".to_string());
1969
1970 if mode != "single" && mode != "multi_tag" {
1972 return syn::Error::new(
1973 input.ident.span(),
1974 "`mode` must be either \"single\" or \"multi_tag\"",
1975 )
1976 .to_compile_error()
1977 .into();
1978 }
1979
1980 let prompt_template = match prompt_template {
1982 Some(p) => p,
1983 None => {
1984 return syn::Error::new(
1985 input.ident.span(),
1986 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1987 )
1988 .to_compile_error()
1989 .into();
1990 }
1991 };
1992
1993 if mode == "multi_tag" {
1995 let enum_name = &input.ident;
1996 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1997 return generate_multi_tag_output(
1998 &input,
1999 enum_name,
2000 enum_data,
2001 prompt_template,
2002 actions_doc,
2003 );
2004 }
2005
2006 let extractor_tag = match extractor_tag {
2008 Some(t) => t,
2009 None => {
2010 return syn::Error::new(
2011 input.ident.span(),
2012 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2013 )
2014 .to_compile_error()
2015 .into();
2016 }
2017 };
2018
2019 let enum_name = &input.ident;
2021 let enum_docs = extract_doc_comments(&input.attrs);
2022
2023 let mut intents_doc_lines = Vec::new();
2024
2025 if !enum_docs.is_empty() {
2027 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2028 } else {
2029 intents_doc_lines.push(format!("{}:", enum_name));
2030 }
2031 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2033
2034 for variant in &enum_data.variants {
2036 let variant_name = &variant.ident;
2037 let variant_docs = extract_doc_comments(&variant.attrs);
2038
2039 if !variant_docs.is_empty() {
2040 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2041 } else {
2042 intents_doc_lines.push(format!("- {}", variant_name));
2043 }
2044 }
2045
2046 let intents_doc_str = intents_doc_lines.join("\n");
2047
2048 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2050 let user_variables: Vec<String> = placeholders
2051 .iter()
2052 .filter_map(|(name, _)| {
2053 if name != "intents_doc" {
2054 Some(name.clone())
2055 } else {
2056 None
2057 }
2058 })
2059 .collect();
2060
2061 let enum_name_str = enum_name.to_string();
2063 let snake_case_name = to_snake_case(&enum_name_str);
2064 let function_name = syn::Ident::new(
2065 &format!("build_{}_prompt", snake_case_name),
2066 proc_macro2::Span::call_site(),
2067 );
2068
2069 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2071 .iter()
2072 .map(|var| {
2073 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2074 quote! { #ident: &str }
2075 })
2076 .collect();
2077
2078 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2080 .iter()
2081 .map(|var| {
2082 let var_str = var.clone();
2083 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2084 quote! {
2085 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2086 }
2087 })
2088 .collect();
2089
2090 let converted_template = prompt_template.clone();
2092
2093 let extractor_name = syn::Ident::new(
2095 &format!("{}Extractor", enum_name),
2096 proc_macro2::Span::call_site(),
2097 );
2098
2099 let filtered_attrs: Vec<_> = input
2101 .attrs
2102 .iter()
2103 .filter(|attr| !attr.path().is_ident("intent"))
2104 .collect();
2105
2106 let vis = &input.vis;
2108 let generics = &input.generics;
2109 let variants = &enum_data.variants;
2110 let enum_output = quote! {
2111 #(#filtered_attrs)*
2112 #vis enum #enum_name #generics {
2113 #variants
2114 }
2115 };
2116
2117 let expanded = quote! {
2119 #enum_output
2121
2122 pub fn #function_name(#(#function_params),*) -> String {
2124 let mut env = minijinja::Environment::new();
2125 env.add_template("prompt", #converted_template)
2126 .expect("Failed to parse intent prompt template");
2127
2128 let tmpl = env.get_template("prompt").unwrap();
2129
2130 let mut __template_context = std::collections::HashMap::new();
2131
2132 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2134
2135 #(#context_insertions)*
2137
2138 tmpl.render(&__template_context)
2139 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2140 }
2141
2142 pub struct #extractor_name;
2144
2145 impl #extractor_name {
2146 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2147 }
2148
2149 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2150 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2151 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2153 }
2154 }
2155 };
2156
2157 TokenStream::from(expanded)
2158}
2159
/// Converts a `CamelCase` identifier into `snake_case` for generated function names.
///
/// An underscore is inserted before an uppercase character only when it is not the first
/// character and the previous character was not uppercase, so runs of capitals stay
/// together (`UserIntent` -> `user_intent`, `HTTPServer` -> `httpserver`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}"); keep them all
            // instead of truncating to the first one.
            result.extend(ch.to_lowercase());
            prev_upper = true;
        } else {
            result.push(ch);
            prev_upper = false;
        }
    }

    result
}
2180
/// Variant-level options read from `#[action(...)]` on a multi-tag intent enum.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// XML tag from `#[action(tag = "...")]`; `None` when the variant declares no tag.
    tag: Option<String>,
}
2186
2187fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2188 let mut result = ActionAttrs::default();
2189
2190 for attr in attrs {
2191 if attr.path().is_ident("action")
2192 && let Ok(meta_list) = attr.meta.require_list()
2193 && let Ok(metas) =
2194 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2195 {
2196 for meta in metas {
2197 if let Meta::NameValue(nv) = meta
2198 && nv.path.is_ident("tag")
2199 && let syn::Expr::Lit(syn::ExprLit {
2200 lit: syn::Lit::Str(lit_str),
2201 ..
2202 }) = nv.value
2203 {
2204 result.tag = Some(lit_str.value());
2205 }
2206 }
2207 }
2208 }
2209
2210 result
2211}
2212
/// Field-level markers read from `#[action(...)]` on named fields of a multi-tag variant.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// `#[action(attribute)]`: the value is read from an XML attribute of the same name.
    is_attribute: bool,
    /// `#[action(inner_text)]`: the value is the text between the opening and closing tags.
    is_inner_text: bool,
}
2219
2220fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2221 let mut result = FieldActionAttrs::default();
2222
2223 for attr in attrs {
2224 if attr.path().is_ident("action")
2225 && let Ok(meta_list) = attr.meta.require_list()
2226 {
2227 let tokens_str = meta_list.tokens.to_string();
2228 if tokens_str == "attribute" {
2229 result.is_attribute = true;
2230 } else if tokens_str == "inner_text" {
2231 result.is_inner_text = true;
2232 }
2233 }
2234 }
2235
2236 result
2237}
2238
2239fn generate_multi_tag_actions_doc(
2241 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2242) -> String {
2243 let mut doc_lines = Vec::new();
2244
2245 for variant in variants {
2246 let action_attrs = parse_action_attrs(&variant.attrs);
2247
2248 if let Some(tag) = action_attrs.tag {
2249 let variant_docs = extract_doc_comments(&variant.attrs);
2250
2251 match &variant.fields {
2252 syn::Fields::Unit => {
2253 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2255 }
2256 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2257 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2259 }
2260 syn::Fields::Named(fields) => {
2261 let mut attrs_str = Vec::new();
2263 let mut has_inner_text = false;
2264
2265 for field in &fields.named {
2266 let field_name = field.ident.as_ref().unwrap();
2267 let field_attrs = parse_field_action_attrs(&field.attrs);
2268
2269 if field_attrs.is_attribute {
2270 attrs_str.push(format!("{}=\"...\"", field_name));
2271 } else if field_attrs.is_inner_text {
2272 has_inner_text = true;
2273 }
2274 }
2275
2276 let attrs_part = if !attrs_str.is_empty() {
2277 format!(" {}", attrs_str.join(" "))
2278 } else {
2279 String::new()
2280 };
2281
2282 if has_inner_text {
2283 doc_lines.push(format!(
2284 "- `<{}{}>...</{}>`: {}",
2285 tag, attrs_part, tag, variant_docs
2286 ));
2287 } else if !attrs_str.is_empty() {
2288 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2289 } else {
2290 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2291 }
2292
2293 for field in &fields.named {
2295 let field_name = field.ident.as_ref().unwrap();
2296 let field_attrs = parse_field_action_attrs(&field.attrs);
2297 let field_docs = extract_doc_comments(&field.attrs);
2298
2299 if field_attrs.is_attribute {
2300 doc_lines
2301 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2302 } else if field_attrs.is_inner_text {
2303 doc_lines
2304 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2305 }
2306 }
2307 }
2308 _ => {
2309 }
2311 }
2312 }
2313 }
2314
2315 doc_lines.join("\n")
2316}
2317
2318fn generate_tags_regex(
2320 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2321) -> String {
2322 let mut tag_names = Vec::new();
2323
2324 for variant in variants {
2325 let action_attrs = parse_action_attrs(&variant.attrs);
2326 if let Some(tag) = action_attrs.tag {
2327 tag_names.push(tag);
2328 }
2329 }
2330
2331 if tag_names.is_empty() {
2332 return String::new();
2333 }
2334
2335 let tags_pattern = tag_names.join("|");
2336 format!(
2339 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2340 tags_pattern, tags_pattern, tags_pattern
2341 )
2342}
2343
2344fn generate_multi_tag_output(
2346 input: &DeriveInput,
2347 enum_name: &syn::Ident,
2348 enum_data: &syn::DataEnum,
2349 prompt_template: String,
2350 actions_doc: String,
2351) -> TokenStream {
2352 let found_crate =
2353 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2354 let crate_path = match found_crate {
2355 FoundCrate::Itself => {
2356 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2358 quote!(::#ident)
2359 }
2360 FoundCrate::Name(name) => {
2361 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2362 quote!(::#ident)
2363 }
2364 };
2365
2366 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2368 let user_variables: Vec<String> = placeholders
2369 .iter()
2370 .filter_map(|(name, _)| {
2371 if name != "actions_doc" {
2372 Some(name.clone())
2373 } else {
2374 None
2375 }
2376 })
2377 .collect();
2378
2379 let enum_name_str = enum_name.to_string();
2381 let snake_case_name = to_snake_case(&enum_name_str);
2382 let function_name = syn::Ident::new(
2383 &format!("build_{}_prompt", snake_case_name),
2384 proc_macro2::Span::call_site(),
2385 );
2386
2387 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2389 .iter()
2390 .map(|var| {
2391 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2392 quote! { #ident: &str }
2393 })
2394 .collect();
2395
2396 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2398 .iter()
2399 .map(|var| {
2400 let var_str = var.clone();
2401 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2402 quote! {
2403 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2404 }
2405 })
2406 .collect();
2407
2408 let extractor_name = syn::Ident::new(
2410 &format!("{}Extractor", enum_name),
2411 proc_macro2::Span::call_site(),
2412 );
2413
2414 let filtered_attrs: Vec<_> = input
2416 .attrs
2417 .iter()
2418 .filter(|attr| !attr.path().is_ident("intent"))
2419 .collect();
2420
2421 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2423 .variants
2424 .iter()
2425 .map(|variant| {
2426 let variant_name = &variant.ident;
2427 let variant_attrs: Vec<_> = variant
2428 .attrs
2429 .iter()
2430 .filter(|attr| !attr.path().is_ident("action"))
2431 .collect();
2432 let fields = &variant.fields;
2433
2434 let filtered_fields = match fields {
2436 syn::Fields::Named(named_fields) => {
2437 let filtered: Vec<_> = named_fields
2438 .named
2439 .iter()
2440 .map(|field| {
2441 let field_name = &field.ident;
2442 let field_type = &field.ty;
2443 let field_vis = &field.vis;
2444 let filtered_attrs: Vec<_> = field
2445 .attrs
2446 .iter()
2447 .filter(|attr| !attr.path().is_ident("action"))
2448 .collect();
2449 quote! {
2450 #(#filtered_attrs)*
2451 #field_vis #field_name: #field_type
2452 }
2453 })
2454 .collect();
2455 quote! { { #(#filtered,)* } }
2456 }
2457 syn::Fields::Unnamed(unnamed_fields) => {
2458 let types: Vec<_> = unnamed_fields
2459 .unnamed
2460 .iter()
2461 .map(|field| {
2462 let field_type = &field.ty;
2463 quote! { #field_type }
2464 })
2465 .collect();
2466 quote! { (#(#types),*) }
2467 }
2468 syn::Fields::Unit => quote! {},
2469 };
2470
2471 quote! {
2472 #(#variant_attrs)*
2473 #variant_name #filtered_fields
2474 }
2475 })
2476 .collect();
2477
2478 let vis = &input.vis;
2479 let generics = &input.generics;
2480
2481 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2483
2484 let tags_regex = generate_tags_regex(&enum_data.variants);
2486
2487 let expanded = quote! {
2488 #(#filtered_attrs)*
2490 #vis enum #enum_name #generics {
2491 #(#filtered_variants),*
2492 }
2493
2494 pub fn #function_name(#(#function_params),*) -> String {
2496 let mut env = minijinja::Environment::new();
2497 env.add_template("prompt", #prompt_template)
2498 .expect("Failed to parse intent prompt template");
2499
2500 let tmpl = env.get_template("prompt").unwrap();
2501
2502 let mut __template_context = std::collections::HashMap::new();
2503
2504 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2506
2507 #(#context_insertions)*
2509
2510 tmpl.render(&__template_context)
2511 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2512 }
2513
2514 pub struct #extractor_name;
2516
2517 impl #extractor_name {
2518 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2519 use ::quick_xml::events::Event;
2520 use ::quick_xml::Reader;
2521
2522 let mut actions = Vec::new();
2523 let mut reader = Reader::from_str(text);
2524 reader.config_mut().trim_text(true);
2525
2526 let mut buf = Vec::new();
2527
2528 loop {
2529 match reader.read_event_into(&mut buf) {
2530 Ok(Event::Start(e)) => {
2531 let owned_e = e.into_owned();
2532 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2533 let is_empty = false;
2534
2535 #parsing_arms
2536 }
2537 Ok(Event::Empty(e)) => {
2538 let owned_e = e.into_owned();
2539 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2540 let is_empty = true;
2541
2542 #parsing_arms
2543 }
2544 Ok(Event::Eof) => break,
2545 Err(_) => {
2546 break;
2548 }
2549 _ => {}
2550 }
2551 buf.clear();
2552 }
2553
2554 actions.into_iter().next()
2555 }
2556
2557 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2558 use ::quick_xml::events::Event;
2559 use ::quick_xml::Reader;
2560
2561 let mut actions = Vec::new();
2562 let mut reader = Reader::from_str(text);
2563 reader.config_mut().trim_text(true);
2564
2565 let mut buf = Vec::new();
2566
2567 loop {
2568 match reader.read_event_into(&mut buf) {
2569 Ok(Event::Start(e)) => {
2570 let owned_e = e.into_owned();
2571 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2572 let is_empty = false;
2573
2574 #parsing_arms
2575 }
2576 Ok(Event::Empty(e)) => {
2577 let owned_e = e.into_owned();
2578 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2579 let is_empty = true;
2580
2581 #parsing_arms
2582 }
2583 Ok(Event::Eof) => break,
2584 Err(_) => {
2585 break;
2587 }
2588 _ => {}
2589 }
2590 buf.clear();
2591 }
2592
2593 Ok(actions)
2594 }
2595
2596 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2597 where
2598 F: FnMut(#enum_name) -> String,
2599 {
2600 use ::regex::Regex;
2601
2602 let regex_pattern = #tags_regex;
2603 if regex_pattern.is_empty() {
2604 return text.to_string();
2605 }
2606
2607 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2608 panic!("Failed to compile regex for action tags: {}", e);
2609 });
2610
2611 re.replace_all(text, |caps: &::regex::Captures| {
2612 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2613
2614 if let Some(action) = self.parse_single_action(matched) {
2616 transformer(action)
2617 } else {
2618 matched.to_string()
2620 }
2621 }).to_string()
2622 }
2623
2624 pub fn strip_actions(&self, text: &str) -> String {
2625 self.transform_actions(text, |_| String::new())
2626 }
2627 }
2628 };
2629
2630 TokenStream::from(expanded)
2631}
2632
2633fn generate_parsing_arms(
2635 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2636 enum_name: &syn::Ident,
2637) -> proc_macro2::TokenStream {
2638 let mut arms = Vec::new();
2639
2640 for variant in variants {
2641 let variant_name = &variant.ident;
2642 let action_attrs = parse_action_attrs(&variant.attrs);
2643
2644 if let Some(tag) = action_attrs.tag {
2645 match &variant.fields {
2646 syn::Fields::Unit => {
2647 arms.push(quote! {
2649 if &tag_name == #tag {
2650 actions.push(#enum_name::#variant_name);
2651 }
2652 });
2653 }
2654 syn::Fields::Unnamed(_fields) => {
2655 arms.push(quote! {
2657 if &tag_name == #tag && !is_empty {
2658 match reader.read_text(owned_e.name()) {
2660 Ok(text) => {
2661 actions.push(#enum_name::#variant_name(text.to_string()));
2662 }
2663 Err(_) => {
2664 actions.push(#enum_name::#variant_name(String::new()));
2666 }
2667 }
2668 }
2669 });
2670 }
2671 syn::Fields::Named(fields) => {
2672 let mut field_names = Vec::new();
2674 let mut has_inner_text_field = None;
2675
2676 for field in &fields.named {
2677 let field_name = field.ident.as_ref().unwrap();
2678 let field_attrs = parse_field_action_attrs(&field.attrs);
2679
2680 if field_attrs.is_attribute {
2681 field_names.push(field_name.clone());
2682 } else if field_attrs.is_inner_text {
2683 has_inner_text_field = Some(field_name.clone());
2684 }
2685 }
2686
2687 if let Some(inner_text_field) = has_inner_text_field {
2688 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2691 quote! {
2692 let mut #field_name = String::new();
2693 for attr in owned_e.attributes() {
2694 if let Ok(attr) = attr {
2695 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2696 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2697 break;
2698 }
2699 }
2700 }
2701 }
2702 }).collect();
2703
2704 arms.push(quote! {
2705 if &tag_name == #tag {
2706 #(#attr_extractions)*
2707
2708 if is_empty {
2710 let #inner_text_field = String::new();
2711 actions.push(#enum_name::#variant_name {
2712 #(#field_names,)*
2713 #inner_text_field,
2714 });
2715 } else {
2716 match reader.read_text(owned_e.name()) {
2718 Ok(text) => {
2719 let #inner_text_field = text.to_string();
2720 actions.push(#enum_name::#variant_name {
2721 #(#field_names,)*
2722 #inner_text_field,
2723 });
2724 }
2725 Err(_) => {
2726 let #inner_text_field = String::new();
2728 actions.push(#enum_name::#variant_name {
2729 #(#field_names,)*
2730 #inner_text_field,
2731 });
2732 }
2733 }
2734 }
2735 }
2736 });
2737 } else {
2738 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2740 quote! {
2741 let mut #field_name = String::new();
2742 for attr in owned_e.attributes() {
2743 if let Ok(attr) = attr {
2744 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2745 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2746 break;
2747 }
2748 }
2749 }
2750 }
2751 }).collect();
2752
2753 arms.push(quote! {
2754 if &tag_name == #tag {
2755 #(#attr_extractions)*
2756 actions.push(#enum_name::#variant_name {
2757 #(#field_names),*
2758 });
2759 }
2760 });
2761 }
2762 }
2763 }
2764 }
2765 }
2766
2767 quote! {
2768 #(#arms)*
2769 }
2770}
2771
2772#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2774pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2775 let input = parse_macro_input!(input as DeriveInput);
2776
2777 let found_crate =
2778 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2779 let crate_path = match found_crate {
2780 FoundCrate::Itself => {
2781 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2783 quote!(::#ident)
2784 }
2785 FoundCrate::Name(name) => {
2786 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2787 quote!(::#ident)
2788 }
2789 };
2790
2791 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2793
2794 let struct_name = &input.ident;
2795 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2796
2797 let placeholders = parse_template_placeholders_with_mode(&template);
2799
2800 let mut converted_template = template.clone();
2802 let mut context_fields = Vec::new();
2803
2804 let fields = match &input.data {
2806 Data::Struct(data_struct) => match &data_struct.fields {
2807 syn::Fields::Named(fields) => &fields.named,
2808 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2809 },
2810 _ => panic!("ToPromptFor is only supported for structs"),
2811 };
2812
2813 let has_mode_support = input.attrs.iter().any(|attr| {
2815 if attr.path().is_ident("prompt")
2816 && let Ok(metas) =
2817 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2818 {
2819 for meta in metas {
2820 if let Meta::NameValue(nv) = meta
2821 && nv.path.is_ident("mode")
2822 {
2823 return true;
2824 }
2825 }
2826 }
2827 false
2828 });
2829
2830 for (placeholder_name, mode_opt) in &placeholders {
2832 if placeholder_name == "self" {
2833 if let Some(specific_mode) = mode_opt {
2834 let unique_key = format!("self__{}", specific_mode);
2836
2837 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2839 let replacement = format!("{{{{ {} }}}}", unique_key);
2840 converted_template = converted_template.replace(&pattern, &replacement);
2841
2842 context_fields.push(quote! {
2844 context.insert(
2845 #unique_key.to_string(),
2846 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2847 );
2848 });
2849 } else {
2850 if has_mode_support {
2853 context_fields.push(quote! {
2855 context.insert(
2856 "self".to_string(),
2857 minijinja::Value::from(self.to_prompt_with_mode(mode))
2858 );
2859 });
2860 } else {
2861 context_fields.push(quote! {
2863 context.insert(
2864 "self".to_string(),
2865 minijinja::Value::from(self.to_prompt())
2866 );
2867 });
2868 }
2869 }
2870 } else {
2871 let field_exists = fields.iter().any(|f| {
2874 f.ident
2875 .as_ref()
2876 .is_some_and(|ident| ident == placeholder_name)
2877 });
2878
2879 if field_exists {
2880 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2881
2882 context_fields.push(quote! {
2886 context.insert(
2887 #placeholder_name.to_string(),
2888 minijinja::Value::from_serialize(&self.#field_ident)
2889 );
2890 });
2891 }
2892 }
2894 }
2895
2896 let expanded = quote! {
2897 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2898 where
2899 #target_type: serde::Serialize,
2900 {
2901 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2902 let mut env = minijinja::Environment::new();
2904 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2905 panic!("Failed to parse template: {}", e)
2906 });
2907
2908 let tmpl = env.get_template("prompt").unwrap();
2909
2910 let mut context = std::collections::HashMap::new();
2912 context.insert(
2914 "self".to_string(),
2915 minijinja::Value::from_serialize(self)
2916 );
2917 context.insert(
2919 "target".to_string(),
2920 minijinja::Value::from_serialize(target)
2921 );
2922 #(#context_fields)*
2923
2924 tmpl.render(context).unwrap_or_else(|e| {
2926 format!("Failed to render prompt: {}", e)
2927 })
2928 }
2929 }
2930 };
2931
2932 TokenStream::from(expanded)
2933}
2934
/// Options read from an `#[agent(...)]` attribute list.
///
/// Used by both `#[derive(Agent)]` and the `#[agent(...)]` attribute macro.
/// Every field is `None` when its key is absent from the attribute.
struct AgentAttrs {
    /// `expertise = "..."`: base text returned by `Agent::expertise`. For
    /// non-`String` outputs, JSON-response instructions are appended to it.
    expertise: Option<String>,
    /// `output = "Type"`: the agent's `Output` type, written as a string literal.
    output: Option<syn::Type>,
    /// `backend = "..."`: backend agent to use. `"gemini"` selects `GeminiAgent`;
    /// other values fall back to `ClaudeCodeAgent`.
    backend: Option<String>,
    /// `model = "..."`: model name passed to the backend's `with_model_str`.
    model: Option<String>,
    /// `inner = "A"`: name of the generic parameter holding the wrapped agent
    /// (attribute macro only).
    inner: Option<String>,
    /// `default_inner = "path::Type"`: default type for that generic parameter
    /// (attribute macro only).
    default_inner: Option<String>,
    /// `max_retries = N`: attempt limit for the derive macro's retry loop.
    max_retries: Option<u32>,
    /// `profile = "..."`: execution profile name, e.g. `"creative"`,
    /// `"balanced"` or `"deterministic"`.
    profile: Option<String>,
}
2950
2951impl Parse for AgentAttrs {
2952 fn parse(input: ParseStream) -> syn::Result<Self> {
2953 let mut expertise = None;
2954 let mut output = None;
2955 let mut backend = None;
2956 let mut model = None;
2957 let mut inner = None;
2958 let mut default_inner = None;
2959 let mut max_retries = None;
2960 let mut profile = None;
2961
2962 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2963
2964 for meta in pairs {
2965 match meta {
2966 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2967 if let syn::Expr::Lit(syn::ExprLit {
2968 lit: syn::Lit::Str(lit_str),
2969 ..
2970 }) = &nv.value
2971 {
2972 expertise = Some(lit_str.value());
2973 }
2974 }
2975 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2976 if let syn::Expr::Lit(syn::ExprLit {
2977 lit: syn::Lit::Str(lit_str),
2978 ..
2979 }) = &nv.value
2980 {
2981 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2982 output = Some(ty);
2983 }
2984 }
2985 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2986 if let syn::Expr::Lit(syn::ExprLit {
2987 lit: syn::Lit::Str(lit_str),
2988 ..
2989 }) = &nv.value
2990 {
2991 backend = Some(lit_str.value());
2992 }
2993 }
2994 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2995 if let syn::Expr::Lit(syn::ExprLit {
2996 lit: syn::Lit::Str(lit_str),
2997 ..
2998 }) = &nv.value
2999 {
3000 model = Some(lit_str.value());
3001 }
3002 }
3003 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3004 if let syn::Expr::Lit(syn::ExprLit {
3005 lit: syn::Lit::Str(lit_str),
3006 ..
3007 }) = &nv.value
3008 {
3009 inner = Some(lit_str.value());
3010 }
3011 }
3012 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3013 if let syn::Expr::Lit(syn::ExprLit {
3014 lit: syn::Lit::Str(lit_str),
3015 ..
3016 }) = &nv.value
3017 {
3018 default_inner = Some(lit_str.value());
3019 }
3020 }
3021 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3022 if let syn::Expr::Lit(syn::ExprLit {
3023 lit: syn::Lit::Int(lit_int),
3024 ..
3025 }) = &nv.value
3026 {
3027 max_retries = Some(lit_int.base10_parse()?);
3028 }
3029 }
3030 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3031 if let syn::Expr::Lit(syn::ExprLit {
3032 lit: syn::Lit::Str(lit_str),
3033 ..
3034 }) = &nv.value
3035 {
3036 profile = Some(lit_str.value());
3037 }
3038 }
3039 _ => {}
3040 }
3041 }
3042
3043 Ok(AgentAttrs {
3044 expertise,
3045 output,
3046 backend,
3047 model,
3048 inner,
3049 default_inner,
3050 max_retries,
3051 profile,
3052 })
3053 }
3054}
3055
3056fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3058 for attr in attrs {
3059 if attr.path().is_ident("agent") {
3060 return attr.parse_args::<AgentAttrs>();
3061 }
3062 }
3063
3064 Ok(AgentAttrs {
3065 expertise: None,
3066 output: None,
3067 backend: None,
3068 model: None,
3069 inner: None,
3070 default_inner: None,
3071 max_retries: None,
3072 profile: None,
3073 })
3074}
3075
3076fn generate_backend_constructors(
3078 struct_name: &syn::Ident,
3079 backend: &str,
3080 _model: Option<&str>,
3081 _profile: Option<&str>,
3082 crate_path: &proc_macro2::TokenStream,
3083) -> proc_macro2::TokenStream {
3084 match backend {
3085 "claude" => {
3086 quote! {
3087 impl #struct_name {
3088 pub fn with_claude() -> Self {
3090 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3091 }
3092
3093 pub fn with_claude_model(model: &str) -> Self {
3095 Self::new(
3096 #crate_path::agent::impls::ClaudeCodeAgent::new()
3097 .with_model_str(model)
3098 )
3099 }
3100 }
3101 }
3102 }
3103 "gemini" => {
3104 quote! {
3105 impl #struct_name {
3106 pub fn with_gemini() -> Self {
3108 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3109 }
3110
3111 pub fn with_gemini_model(model: &str) -> Self {
3113 Self::new(
3114 #crate_path::agent::impls::GeminiAgent::new()
3115 .with_model_str(model)
3116 )
3117 }
3118 }
3119 }
3120 }
3121 _ => quote! {},
3122 }
3123}
3124
3125fn generate_default_impl(
3127 struct_name: &syn::Ident,
3128 backend: &str,
3129 model: Option<&str>,
3130 profile: Option<&str>,
3131 crate_path: &proc_macro2::TokenStream,
3132) -> proc_macro2::TokenStream {
3133 let profile_expr = if let Some(profile_str) = profile {
3135 match profile_str.to_lowercase().as_str() {
3136 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3137 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3138 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3139 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3141 } else {
3142 quote! { #crate_path::agent::ExecutionProfile::default() }
3143 };
3144
3145 let agent_init = match backend {
3146 "gemini" => {
3147 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3148
3149 if let Some(model_str) = model {
3150 builder = quote! { #builder.with_model_str(#model_str) };
3151 }
3152
3153 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3154 builder
3155 }
3156 _ => {
3157 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3159
3160 if let Some(model_str) = model {
3161 builder = quote! { #builder.with_model_str(#model_str) };
3162 }
3163
3164 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3165 builder
3166 }
3167 };
3168
3169 quote! {
3170 impl Default for #struct_name {
3171 fn default() -> Self {
3172 Self::new(#agent_init)
3173 }
3174 }
3175 }
3176}
3177
3178#[proc_macro_derive(Agent, attributes(agent))]
3187pub fn derive_agent(input: TokenStream) -> TokenStream {
3188 let input = parse_macro_input!(input as DeriveInput);
3189 let struct_name = &input.ident;
3190
3191 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3193 Ok(attrs) => attrs,
3194 Err(e) => return e.to_compile_error().into(),
3195 };
3196
3197 let expertise = agent_attrs
3198 .expertise
3199 .unwrap_or_else(|| String::from("general AI assistant"));
3200 let output_type = agent_attrs
3201 .output
3202 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3203 let backend = agent_attrs
3204 .backend
3205 .unwrap_or_else(|| String::from("claude"));
3206 let model = agent_attrs.model;
3207 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3212 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3213 let crate_path = match found_crate {
3214 FoundCrate::Itself => {
3215 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3217 quote!(::#ident)
3218 }
3219 FoundCrate::Name(name) => {
3220 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3221 quote!(::#ident)
3222 }
3223 };
3224
3225 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3226
3227 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3229 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3230
3231 let enhanced_expertise = if is_string_output {
3233 quote! { #expertise }
3235 } else {
3236 let type_name = quote!(#output_type).to_string();
3238 quote! {
3239 {
3240 use std::sync::OnceLock;
3241 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3242
3243 EXPERTISE_CACHE.get_or_init(|| {
3244 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3246
3247 if schema.is_empty() {
3248 format!(
3250 concat!(
3251 #expertise,
3252 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3253 "Do not include any text outside the JSON object."
3254 ),
3255 #type_name
3256 )
3257 } else {
3258 format!(
3260 concat!(
3261 #expertise,
3262 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3263 ),
3264 schema
3265 )
3266 }
3267 }).as_str()
3268 }
3269 }
3270 };
3271
3272 let agent_init = match backend.as_str() {
3274 "gemini" => {
3275 if let Some(model_str) = model {
3276 quote! {
3277 use #crate_path::agent::impls::GeminiAgent;
3278 let agent = GeminiAgent::new().with_model_str(#model_str);
3279 }
3280 } else {
3281 quote! {
3282 use #crate_path::agent::impls::GeminiAgent;
3283 let agent = GeminiAgent::new();
3284 }
3285 }
3286 }
3287 "claude" => {
3288 if let Some(model_str) = model {
3289 quote! {
3290 use #crate_path::agent::impls::ClaudeCodeAgent;
3291 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3292 }
3293 } else {
3294 quote! {
3295 use #crate_path::agent::impls::ClaudeCodeAgent;
3296 let agent = ClaudeCodeAgent::new();
3297 }
3298 }
3299 }
3300 _ => {
3301 if let Some(model_str) = model {
3303 quote! {
3304 use #crate_path::agent::impls::ClaudeCodeAgent;
3305 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3306 }
3307 } else {
3308 quote! {
3309 use #crate_path::agent::impls::ClaudeCodeAgent;
3310 let agent = ClaudeCodeAgent::new();
3311 }
3312 }
3313 }
3314 };
3315
3316 let expanded = quote! {
3317 #[async_trait::async_trait]
3318 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3319 type Output = #output_type;
3320
3321 fn expertise(&self) -> &str {
3322 #enhanced_expertise
3323 }
3324
3325 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3326 #agent_init
3328
3329 let max_retries: u32 = #max_retries;
3331 let mut attempts = 0u32;
3332
3333 loop {
3334 attempts += 1;
3335
3336 let result = async {
3338 let response = agent.execute(intent.clone()).await?;
3339
3340 let json_str = #crate_path::extract_json(&response)
3342 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3343
3344 serde_json::from_str::<Self::Output>(&json_str)
3346 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3347 }.await;
3348
3349 match result {
3350 Ok(output) => return Ok(output),
3351 Err(e) if e.is_retryable() && attempts < max_retries => {
3352 log::warn!(
3354 "Agent execution failed (attempt {}/{}): {}. Retrying...",
3355 attempts,
3356 max_retries,
3357 e
3358 );
3359
3360 let delay_ms = 100 * attempts as u64;
3362 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
3363
3364 continue;
3366 }
3367 Err(e) => {
3368 if attempts > 1 {
3369 log::error!(
3370 "Agent execution failed after {} attempts: {}",
3371 attempts,
3372 e
3373 );
3374 }
3375 return Err(e);
3376 }
3377 }
3378 }
3379 }
3380
3381 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3382 #agent_init
3384 agent.is_available().await
3385 }
3386 }
3387 };
3388
3389 TokenStream::from(expanded)
3390}
3391
3392#[proc_macro_attribute]
3407pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3408 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3410 Ok(attrs) => attrs,
3411 Err(e) => return e.to_compile_error().into(),
3412 };
3413
3414 let input = parse_macro_input!(item as DeriveInput);
3416 let struct_name = &input.ident;
3417 let vis = &input.vis;
3418
3419 let expertise = agent_attrs
3420 .expertise
3421 .unwrap_or_else(|| String::from("general AI assistant"));
3422 let output_type = agent_attrs
3423 .output
3424 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3425 let backend = agent_attrs
3426 .backend
3427 .unwrap_or_else(|| String::from("claude"));
3428 let model = agent_attrs.model;
3429 let profile = agent_attrs.profile;
3430
3431 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3433 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3434
3435 let found_crate =
3437 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3438 let crate_path = match found_crate {
3439 FoundCrate::Itself => {
3440 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3441 quote!(::#ident)
3442 }
3443 FoundCrate::Name(name) => {
3444 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3445 quote!(::#ident)
3446 }
3447 };
3448
3449 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3451 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3452
3453 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3455 let type_path: syn::Type =
3457 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3458 quote! { #type_path }
3459 } else {
3460 match backend.as_str() {
3462 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3463 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3464 }
3465 };
3466
3467 let struct_def = quote! {
3469 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3470 inner: #inner_generic_ident,
3471 }
3472 };
3473
3474 let constructors = quote! {
3476 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3477 pub fn new(inner: #inner_generic_ident) -> Self {
3479 Self { inner }
3480 }
3481 }
3482 };
3483
3484 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3486 let default_impl = quote! {
3488 impl Default for #struct_name {
3489 fn default() -> Self {
3490 Self {
3491 inner: <#default_agent_type as Default>::default(),
3492 }
3493 }
3494 }
3495 };
3496 (quote! {}, default_impl)
3497 } else {
3498 let backend_constructors = generate_backend_constructors(
3500 struct_name,
3501 &backend,
3502 model.as_deref(),
3503 profile.as_deref(),
3504 &crate_path,
3505 );
3506 let default_impl = generate_default_impl(
3507 struct_name,
3508 &backend,
3509 model.as_deref(),
3510 profile.as_deref(),
3511 &crate_path,
3512 );
3513 (backend_constructors, default_impl)
3514 };
3515
3516 let enhanced_expertise = if is_string_output {
3518 quote! { #expertise }
3520 } else {
3521 let type_name = quote!(#output_type).to_string();
3523 quote! {
3524 {
3525 use std::sync::OnceLock;
3526 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3527
3528 EXPERTISE_CACHE.get_or_init(|| {
3529 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3531
3532 if schema.is_empty() {
3533 format!(
3535 concat!(
3536 #expertise,
3537 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3538 "Do not include any text outside the JSON object."
3539 ),
3540 #type_name
3541 )
3542 } else {
3543 format!(
3545 concat!(
3546 #expertise,
3547 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3548 ),
3549 schema
3550 )
3551 }
3552 }).as_str()
3553 }
3554 }
3555 };
3556
3557 let agent_impl = quote! {
3559 #[async_trait::async_trait]
3560 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3561 where
3562 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3563 {
3564 type Output = #output_type;
3565
3566 fn expertise(&self) -> &str {
3567 #enhanced_expertise
3568 }
3569
3570 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3571 let enhanced_payload = intent.prepend_text(self.expertise());
3573
3574 let response = self.inner.execute(enhanced_payload).await?;
3576
3577 let json_str = #crate_path::extract_json(&response)
3579 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3580
3581 serde_json::from_str(&json_str)
3583 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3584 }
3585
3586 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3587 self.inner.is_available().await
3588 }
3589 }
3590 };
3591
3592 let expanded = quote! {
3593 #struct_def
3594 #constructors
3595 #backend_constructors
3596 #default_impl
3597 #agent_impl
3598 };
3599
3600 TokenStream::from(expanded)
3601}
3602
3603#[proc_macro_derive(TypeMarker)]
3625pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3626 let input = parse_macro_input!(input as DeriveInput);
3627 let struct_name = &input.ident;
3628 let type_name_str = struct_name.to_string();
3629
3630 let found_crate =
3632 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3633 let crate_path = match found_crate {
3634 FoundCrate::Itself => {
3635 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3636 quote!(::#ident)
3637 }
3638 FoundCrate::Name(name) => {
3639 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3640 quote!(::#ident)
3641 }
3642 };
3643
3644 let expanded = quote! {
3645 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3646 const TYPE_NAME: &'static str = #type_name_str;
3647 }
3648 };
3649
3650 TokenStream::from(expanded)
3651}
3652
3653#[proc_macro_attribute]
3689pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
3690 let input = parse_macro_input!(item as syn::DeriveInput);
3691 let struct_name = &input.ident;
3692 let vis = &input.vis;
3693 let type_name_str = struct_name.to_string();
3694
3695 let default_fn_name = syn::Ident::new(
3697 &format!("default_{}_type", to_snake_case(&type_name_str)),
3698 struct_name.span(),
3699 );
3700
3701 let found_crate =
3703 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3704 let crate_path = match found_crate {
3705 FoundCrate::Itself => {
3706 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3707 quote!(::#ident)
3708 }
3709 FoundCrate::Name(name) => {
3710 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3711 quote!(::#ident)
3712 }
3713 };
3714
3715 let fields = match &input.data {
3717 syn::Data::Struct(data_struct) => match &data_struct.fields {
3718 syn::Fields::Named(fields) => &fields.named,
3719 _ => {
3720 return syn::Error::new_spanned(
3721 struct_name,
3722 "type_marker only works with structs with named fields",
3723 )
3724 .to_compile_error()
3725 .into();
3726 }
3727 },
3728 _ => {
3729 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
3730 .to_compile_error()
3731 .into();
3732 }
3733 };
3734
3735 let mut new_fields = vec![];
3737
3738 let default_fn_name_str = default_fn_name.to_string();
3740 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
3741
3742 new_fields.push(quote! {
3747 #[serde(default = #default_fn_name_lit)]
3748 __type: String
3749 });
3750
3751 for field in fields {
3753 new_fields.push(quote! { #field });
3754 }
3755
3756 let attrs = &input.attrs;
3758 let generics = &input.generics;
3759
3760 let expanded = quote! {
3761 fn #default_fn_name() -> String {
3763 #type_name_str.to_string()
3764 }
3765
3766 #(#attrs)*
3768 #vis struct #struct_name #generics {
3769 #(#new_fields),*
3770 }
3771
3772 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3774 const TYPE_NAME: &'static str = #type_name_str;
3775 }
3776 };
3777
3778 TokenStream::from(expanded)
3779}