1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144
145 for (i, field) in fields.iter().enumerate() {
147 let field_name = field.ident.as_ref().unwrap();
148 let field_name_str = field_name.to_string();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if field_name_str == "__type" {
155 continue;
156 }
157
158 if attrs.skip {
160 continue;
161 }
162
163 let field_docs = extract_doc_comments(&field.attrs);
165
166 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
168
169 let remaining_fields = fields
171 .iter()
172 .skip(i + 1)
173 .filter(|f| {
174 let attrs = parse_field_prompt_attrs(&f.attrs);
175 !attrs.skip
176 })
177 .count();
178 let comma = if remaining_fields > 0 { "," } else { "" };
179
180 if is_vec {
181 let comment = if !field_docs.is_empty() {
183 format!(", // {}", field_docs)
184 } else {
185 String::new()
186 };
187
188 field_schema_parts.push(quote! {
189 {
190 let inner_schema = <#inner_type as #crate_path::prompt::ToPrompt>::prompt_schema();
191 if inner_schema.is_empty() {
192 format!(" \"{}\": \"{}[]\"{}{}", #field_name_str, stringify!(#inner_type).to_lowercase(), #comment, #comma)
194 } else {
195 let inner_lines: Vec<&str> = inner_schema.lines()
197 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
198 .take_while(|line| line.trim() != "}")
199 .collect();
200 let inner_content = inner_lines.join("\n");
201 format!(" \"{}\": [\n {{\n{}\n }}\n ]{}{}",
202 #field_name_str,
203 inner_content.lines()
204 .map(|line| format!(" {}", line))
205 .collect::<Vec<_>>()
206 .join("\n"),
207 #comment,
208 #comma
209 )
210 }
211 }
212 });
213 } else {
214 let field_type = &field.ty;
216 let is_primitive = is_primitive_type(field_type);
217
218 if !is_primitive {
219 let comment = if !field_docs.is_empty() {
221 format!(", // {}", field_docs)
222 } else {
223 String::new()
224 };
225
226 field_schema_parts.push(quote! {
227 {
228 let nested_schema = <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema();
229 if nested_schema.is_empty() {
230 let type_str = stringify!(#field_type).to_lowercase();
232 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
233 } else {
234 let nested_lines: Vec<&str> = nested_schema.lines()
236 .skip_while(|line| line.starts_with("###") || line.trim() == "{")
237 .take_while(|line| line.trim() != "}")
238 .collect();
239
240 if nested_lines.is_empty() {
241 let type_str = stringify!(#field_type).to_lowercase();
243 format!(" \"{}\": \"{}\"{}{}", #field_name_str, type_str, #comment, #comma)
244 } else {
245 let indented_content = nested_lines.iter()
247 .map(|line| format!(" {}", line))
248 .collect::<Vec<_>>()
249 .join("\n");
250 format!(" \"{}\": {{\n{}\n }}{}{}", #field_name_str, indented_content, #comment, #comma)
251 }
252 }
253 }
254 });
255 } else {
256 let type_str = format_type_for_schema(&field.ty);
258 let comment = if !field_docs.is_empty() {
259 format!(", // {}", field_docs)
260 } else {
261 String::new()
262 };
263
264 field_schema_parts.push(quote! {
265 format!(" \"{}\": \"{}\"{}{}", #field_name_str, #type_str, #comment, #comma)
266 });
267 }
268 }
269 }
270
271 let header = if !struct_docs.is_empty() {
273 format!("### Schema for `{}`\n{}", struct_name, struct_docs)
274 } else {
275 format!("### Schema for `{}`", struct_name)
276 };
277
278 quote! {
279 {
280 let mut lines = vec![#header.to_string(), "{".to_string()];
281 #(lines.push(#field_schema_parts);)*
282 lines.push("}".to_string());
283 vec![#crate_path::prompt::PromptPart::Text(lines.join("\n"))]
284 }
285 }
286}
287
288fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
290 if let syn::Type::Path(type_path) = ty
291 && let Some(last_segment) = type_path.path.segments.last()
292 && last_segment.ident == "Vec"
293 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
294 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
295 {
296 return (true, Some(inner_type));
297 }
298 (false, None)
299}
300
301fn is_primitive_type(ty: &syn::Type) -> bool {
303 if let syn::Type::Path(type_path) = ty
304 && let Some(last_segment) = type_path.path.segments.last()
305 {
306 let type_name = last_segment.ident.to_string();
307 matches!(
308 type_name.as_str(),
309 "String"
310 | "str"
311 | "i8"
312 | "i16"
313 | "i32"
314 | "i64"
315 | "i128"
316 | "isize"
317 | "u8"
318 | "u16"
319 | "u32"
320 | "u64"
321 | "u128"
322 | "usize"
323 | "f32"
324 | "f64"
325 | "bool"
326 | "Vec"
327 | "Option"
328 | "HashMap"
329 | "BTreeMap"
330 | "HashSet"
331 | "BTreeSet"
332 )
333 } else {
334 true
336 }
337}
338
339fn format_type_for_schema(ty: &syn::Type) -> String {
341 match ty {
343 syn::Type::Path(type_path) => {
344 let path = &type_path.path;
345 if let Some(last_segment) = path.segments.last() {
346 let type_name = last_segment.ident.to_string();
347
348 if type_name == "Option"
350 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
351 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
352 {
353 return format!("{} | null", format_type_for_schema(inner_type));
354 }
355
356 match type_name.as_str() {
358 "String" | "str" => "string".to_string(),
359 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
360 | "u64" | "u128" | "usize" => "number".to_string(),
361 "f32" | "f64" => "number".to_string(),
362 "bool" => "boolean".to_string(),
363 "Vec" => {
364 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
365 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
366 {
367 return format!("{}[]", format_type_for_schema(inner_type));
368 }
369 "array".to_string()
370 }
371 _ => type_name.to_lowercase(),
372 }
373 } else {
374 "unknown".to_string()
375 }
376 }
377 _ => "unknown".to_string(),
378 }
379}
380
/// What a variant-level `#[prompt(...)]` attribute asks for on an enum.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the enum's schema listing.
    Skip,
    /// `#[prompt("...")]`: an explicit description, used instead of doc comments.
    Description(String),
    /// No recognised `#[prompt(...)]` attribute is present.
    None,
}
387
388fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
390 for attr in attrs {
391 if attr.path().is_ident("prompt") {
392 if let Ok(meta_list) = attr.meta.require_list() {
394 let tokens = &meta_list.tokens;
395 let tokens_str = tokens.to_string();
396 if tokens_str == "skip" {
397 return PromptAttribute::Skip;
398 }
399 }
400
401 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
403 return PromptAttribute::Description(lit_str.value());
404 }
405 }
406 }
407 PromptAttribute::None
408}
409
/// Field-level options parsed from `#[prompt(...)]` on a struct field.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of the generated output.
    skip: bool,
    /// `rename = "..."`: label used for the field in the default `key: value` rendering.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&field` to render its value.
    format_with: Option<String>,
    /// `image`: splice the field's own `to_prompt_parts()` into the output.
    image: bool,
    /// `example = "..."`: literal value shown for the field in the generated example JSON.
    example: Option<String>,
}
419
420fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
422 let mut result = FieldPromptAttrs::default();
423
424 for attr in attrs {
425 if attr.path().is_ident("prompt") {
426 if let Ok(meta_list) = attr.meta.require_list() {
428 if let Ok(metas) =
430 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
431 {
432 for meta in metas {
433 match meta {
434 Meta::Path(path) if path.is_ident("skip") => {
435 result.skip = true;
436 }
437 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
438 if let syn::Expr::Lit(syn::ExprLit {
439 lit: syn::Lit::Str(lit_str),
440 ..
441 }) = nv.value
442 {
443 result.rename = Some(lit_str.value());
444 }
445 }
446 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
447 if let syn::Expr::Lit(syn::ExprLit {
448 lit: syn::Lit::Str(lit_str),
449 ..
450 }) = nv.value
451 {
452 result.format_with = Some(lit_str.value());
453 }
454 }
455 Meta::Path(path) if path.is_ident("image") => {
456 result.image = true;
457 }
458 Meta::NameValue(nv) if nv.path.is_ident("example") => {
459 if let syn::Expr::Lit(syn::ExprLit {
460 lit: syn::Lit::Str(lit_str),
461 ..
462 }) = nv.value
463 {
464 result.example = Some(lit_str.value());
465 }
466 }
467 _ => {}
468 }
469 }
470 } else if meta_list.tokens.to_string() == "skip" {
471 result.skip = true;
473 } else if meta_list.tokens.to_string() == "image" {
474 result.image = true;
476 }
477 }
478 }
479 }
480
481 result
482}
483
484#[proc_macro_derive(ToPrompt, attributes(prompt))]
527pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
528 let input = parse_macro_input!(input as DeriveInput);
529
530 let found_crate =
531 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
532 let crate_path = match found_crate {
533 FoundCrate::Itself => {
534 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
536 quote!(::#ident)
537 }
538 FoundCrate::Name(name) => {
539 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
540 quote!(::#ident)
541 }
542 };
543
544 match &input.data {
546 Data::Enum(data_enum) => {
547 let enum_name = &input.ident;
549 let enum_docs = extract_doc_comments(&input.attrs);
550
551 let mut variant_names = Vec::new();
553 for variant in &data_enum.variants {
554 let variant_name = &variant.ident;
555
556 match parse_prompt_attribute(&variant.attrs) {
558 PromptAttribute::Skip => continue,
559 _ => variant_names.push(variant_name.to_string()),
560 }
561 }
562
563 let variants_list = variant_names.join(" | ");
567 let prompt_string = if !enum_docs.is_empty() {
568 format!("{}: {} (enum: {})", enum_name, enum_docs, variants_list)
569 } else {
570 format!("{} (enum: {})", enum_name, variants_list)
571 };
572 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
573
574 let mut match_arms = Vec::new();
576 for variant in &data_enum.variants {
577 let variant_name = &variant.ident;
578
579 match parse_prompt_attribute(&variant.attrs) {
581 PromptAttribute::Skip => {
582 match_arms.push(quote! {
584 Self::#variant_name => stringify!(#variant_name).to_string()
585 });
586 }
587 PromptAttribute::Description(desc) => {
588 match_arms.push(quote! {
590 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
591 });
592 }
593 PromptAttribute::None => {
594 let variant_docs = extract_doc_comments(&variant.attrs);
596 if !variant_docs.is_empty() {
597 match_arms.push(quote! {
598 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
599 });
600 } else {
601 match_arms.push(quote! {
602 Self::#variant_name => stringify!(#variant_name).to_string()
603 });
604 }
605 }
606 }
607 }
608
609 let to_prompt_impl = if match_arms.is_empty() {
610 quote! {
612 fn to_prompt(&self) -> String {
613 match *self {}
614 }
615 }
616 } else {
617 quote! {
618 fn to_prompt(&self) -> String {
619 match self {
620 #(#match_arms),*
621 }
622 }
623 }
624 };
625
626 let expanded = quote! {
627 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
628 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
629 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
630 }
631
632 #to_prompt_impl
633
634 fn prompt_schema() -> String {
635 #prompt_string.to_string()
636 }
637 }
638 };
639
640 TokenStream::from(expanded)
641 }
642 Data::Struct(data_struct) => {
643 let mut template_attr = None;
645 let mut template_file_attr = None;
646 let mut mode_attr = None;
647 let mut validate_attr = false;
648 let mut type_marker_attr = false;
649
650 for attr in &input.attrs {
651 if attr.path().is_ident("prompt") {
652 if let Ok(metas) =
654 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
655 {
656 for meta in metas {
657 match meta {
658 Meta::NameValue(nv) if nv.path.is_ident("template") => {
659 if let syn::Expr::Lit(expr_lit) = nv.value
660 && let syn::Lit::Str(lit_str) = expr_lit.lit
661 {
662 template_attr = Some(lit_str.value());
663 }
664 }
665 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
666 if let syn::Expr::Lit(expr_lit) = nv.value
667 && let syn::Lit::Str(lit_str) = expr_lit.lit
668 {
669 template_file_attr = Some(lit_str.value());
670 }
671 }
672 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
673 if let syn::Expr::Lit(expr_lit) = nv.value
674 && let syn::Lit::Str(lit_str) = expr_lit.lit
675 {
676 mode_attr = Some(lit_str.value());
677 }
678 }
679 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
680 if let syn::Expr::Lit(expr_lit) = nv.value
681 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
682 {
683 validate_attr = lit_bool.value();
684 }
685 }
686 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
687 if let syn::Expr::Lit(expr_lit) = nv.value
688 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
689 {
690 type_marker_attr = lit_bool.value();
691 }
692 }
693 Meta::Path(path) if path.is_ident("type_marker") => {
694 type_marker_attr = true;
696 }
697 _ => {}
698 }
699 }
700 }
701 }
702 }
703
704 if template_attr.is_some() && template_file_attr.is_some() {
706 return syn::Error::new(
707 input.ident.span(),
708 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
709 ).to_compile_error().into();
710 }
711
712 let template_str = if let Some(file_path) = template_file_attr {
714 let mut full_path = None;
718
719 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
721 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
723
724 if !is_trybuild {
725 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
727 if candidate.exists() {
728 full_path = Some(candidate);
729 }
730 } else {
731 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
737 let workspace_root = &manifest_dir[..target_pos];
738 let original_macros_dir = std::path::Path::new(workspace_root)
740 .join("crates")
741 .join("llm-toolkit-macros");
742
743 let candidate = original_macros_dir.join(&file_path);
744 if candidate.exists() {
745 full_path = Some(candidate);
746 }
747 }
748 }
749 }
750
751 if full_path.is_none() {
753 let candidate = std::path::Path::new(&file_path).to_path_buf();
754 if candidate.exists() {
755 full_path = Some(candidate);
756 }
757 }
758
759 if full_path.is_none()
762 && let Ok(current_dir) = std::env::current_dir()
763 {
764 let mut search_dir = current_dir.as_path();
765 for _ in 0..10 {
767 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
769 if macros_dir.exists() {
770 let candidate = macros_dir.join(&file_path);
771 if candidate.exists() {
772 full_path = Some(candidate);
773 break;
774 }
775 }
776 let candidate = search_dir.join(&file_path);
778 if candidate.exists() {
779 full_path = Some(candidate);
780 break;
781 }
782 if let Some(parent) = search_dir.parent() {
783 search_dir = parent;
784 } else {
785 break;
786 }
787 }
788 }
789
790 if full_path.is_none() {
792 let mut error_msg = format!(
794 "Template file '{}' not found at compile time.\n\nSearched in:",
795 file_path
796 );
797
798 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
799 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
800 error_msg.push_str(&format!("\n - {}", candidate.display()));
801 }
802
803 if let Ok(current_dir) = std::env::current_dir() {
804 let candidate = current_dir.join(&file_path);
805 error_msg.push_str(&format!("\n - {}", candidate.display()));
806 }
807
808 error_msg.push_str("\n\nPlease ensure:");
809 error_msg.push_str("\n 1. The template file exists");
810 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
811 error_msg.push_str("\n 3. There are no typos in the path");
812
813 return syn::Error::new(input.ident.span(), error_msg)
814 .to_compile_error()
815 .into();
816 }
817
818 let final_path = full_path.unwrap();
819
820 match std::fs::read_to_string(&final_path) {
822 Ok(content) => Some(content),
823 Err(e) => {
824 return syn::Error::new(
825 input.ident.span(),
826 format!(
827 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
828 file_path,
829 e,
830 final_path.display()
831 ),
832 )
833 .to_compile_error()
834 .into();
835 }
836 }
837 } else {
838 template_attr
839 };
840
841 if validate_attr && let Some(template) = &template_str {
843 let mut env = minijinja::Environment::new();
845 if let Err(e) = env.add_template("validation", template) {
846 let warning_msg =
848 format!("Template validation warning: Invalid Jinja syntax - {}", e);
849 let warning_ident = syn::Ident::new(
850 "TEMPLATE_VALIDATION_WARNING",
851 proc_macro2::Span::call_site(),
852 );
853 let _warning_tokens = quote! {
854 #[deprecated(note = #warning_msg)]
855 const #warning_ident: () = ();
856 let _ = #warning_ident;
857 };
858 eprintln!("cargo:warning={}", warning_msg);
860 }
861
862 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
864 &fields.named
865 } else {
866 panic!("Template validation is only supported for structs with named fields.");
867 };
868
869 let field_names: std::collections::HashSet<String> = fields
870 .iter()
871 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
872 .collect();
873
874 let placeholders = parse_template_placeholders_with_mode(template);
876
877 for (placeholder_name, _mode) in &placeholders {
878 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
879 let warning_msg = format!(
880 "Template validation warning: Variable '{}' used in template but not found in struct fields",
881 placeholder_name
882 );
883 eprintln!("cargo:warning={}", warning_msg);
884 }
885 }
886 }
887
888 let name = input.ident;
889 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
890
891 let struct_docs = extract_doc_comments(&input.attrs);
893
894 let is_mode_based =
896 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
897
898 let expanded = if is_mode_based || mode_attr.is_some() {
899 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
901 &fields.named
902 } else {
903 panic!(
904 "Mode-based prompt generation is only supported for structs with named fields."
905 );
906 };
907
908 let struct_name_str = name.to_string();
909
910 let has_default = input.attrs.iter().any(|attr| {
912 if attr.path().is_ident("derive")
913 && let Ok(meta_list) = attr.meta.require_list()
914 {
915 let tokens_str = meta_list.tokens.to_string();
916 tokens_str.contains("Default")
917 } else {
918 false
919 }
920 });
921
922 let schema_parts = generate_schema_only_parts(
933 &struct_name_str,
934 &struct_docs,
935 fields,
936 &crate_path,
937 type_marker_attr,
938 );
939
940 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
942
943 quote! {
944 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
945 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
946 match mode {
947 "schema_only" => #schema_parts,
948 "example_only" => #example_parts,
949 "full" | _ => {
950 let mut parts = Vec::new();
952
953 let schema_parts = #schema_parts;
955 parts.extend(schema_parts);
956
957 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
959 parts.push(#crate_path::prompt::PromptPart::Text(
960 format!("Here is an example of a valid `{}` object:", #struct_name_str)
961 ));
962
963 let example_parts = #example_parts;
965 parts.extend(example_parts);
966
967 parts
968 }
969 }
970 }
971
972 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
973 self.to_prompt_parts_with_mode("full")
974 }
975
976 fn to_prompt(&self) -> String {
977 self.to_prompt_parts()
978 .into_iter()
979 .filter_map(|part| match part {
980 #crate_path::prompt::PromptPart::Text(text) => Some(text),
981 _ => None,
982 })
983 .collect::<Vec<_>>()
984 .join("\n")
985 }
986
987 fn prompt_schema() -> String {
988 use std::sync::OnceLock;
989 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
990
991 SCHEMA_CACHE.get_or_init(|| {
992 let schema_parts = #schema_parts;
993 schema_parts
994 .into_iter()
995 .filter_map(|part| match part {
996 #crate_path::prompt::PromptPart::Text(text) => Some(text),
997 _ => None,
998 })
999 .collect::<Vec<_>>()
1000 .join("\n")
1001 }).clone()
1002 }
1003 }
1004 }
1005 } else if let Some(template) = template_str {
1006 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1009 &fields.named
1010 } else {
1011 panic!(
1012 "Template prompt generation is only supported for structs with named fields."
1013 );
1014 };
1015
1016 let placeholders = parse_template_placeholders_with_mode(&template);
1018 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1020 mode.is_some()
1021 && fields
1022 .iter()
1023 .any(|f| f.ident.as_ref().unwrap() == field_name)
1024 });
1025
1026 let mut image_field_parts = Vec::new();
1027 for f in fields.iter() {
1028 let field_name = f.ident.as_ref().unwrap();
1029 let attrs = parse_field_prompt_attrs(&f.attrs);
1030
1031 if attrs.image {
1032 image_field_parts.push(quote! {
1034 parts.extend(self.#field_name.to_prompt_parts());
1035 });
1036 }
1037 }
1038
1039 if has_mode_syntax {
1041 let mut context_fields = Vec::new();
1043 let mut modified_template = template.clone();
1044
1045 for (field_name, mode_opt) in &placeholders {
1047 if let Some(mode) = mode_opt {
1048 let unique_key = format!("{}__{}", field_name, mode);
1050
1051 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1053 let replacement = format!("{{{{ {} }}}}", unique_key);
1054 modified_template = modified_template.replace(&pattern, &replacement);
1055
1056 let field_ident =
1058 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1059
1060 context_fields.push(quote! {
1062 context.insert(
1063 #unique_key.to_string(),
1064 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1065 );
1066 });
1067 }
1068 }
1069
1070 for field in fields.iter() {
1072 let field_name = field.ident.as_ref().unwrap();
1073 let field_name_str = field_name.to_string();
1074
1075 let has_mode_entry = placeholders
1077 .iter()
1078 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1079
1080 if !has_mode_entry {
1081 let is_primitive = match &field.ty {
1084 syn::Type::Path(type_path) => {
1085 if let Some(segment) = type_path.path.segments.last() {
1086 let type_name = segment.ident.to_string();
1087 matches!(
1088 type_name.as_str(),
1089 "String"
1090 | "str"
1091 | "i8"
1092 | "i16"
1093 | "i32"
1094 | "i64"
1095 | "i128"
1096 | "isize"
1097 | "u8"
1098 | "u16"
1099 | "u32"
1100 | "u64"
1101 | "u128"
1102 | "usize"
1103 | "f32"
1104 | "f64"
1105 | "bool"
1106 | "char"
1107 )
1108 } else {
1109 false
1110 }
1111 }
1112 _ => false,
1113 };
1114
1115 if is_primitive {
1116 context_fields.push(quote! {
1117 context.insert(
1118 #field_name_str.to_string(),
1119 minijinja::Value::from_serialize(&self.#field_name)
1120 );
1121 });
1122 } else {
1123 context_fields.push(quote! {
1125 context.insert(
1126 #field_name_str.to_string(),
1127 minijinja::Value::from(self.#field_name.to_prompt())
1128 );
1129 });
1130 }
1131 }
1132 }
1133
1134 quote! {
1135 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1136 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1137 let mut parts = Vec::new();
1138
1139 #(#image_field_parts)*
1141
1142 let text = {
1144 let mut env = minijinja::Environment::new();
1145 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1146 panic!("Failed to parse template: {}", e)
1147 });
1148
1149 let tmpl = env.get_template("prompt").unwrap();
1150
1151 let mut context = std::collections::HashMap::new();
1152 #(#context_fields)*
1153
1154 tmpl.render(context).unwrap_or_else(|e| {
1155 format!("Failed to render prompt: {}", e)
1156 })
1157 };
1158
1159 if !text.is_empty() {
1160 parts.push(#crate_path::prompt::PromptPart::Text(text));
1161 }
1162
1163 parts
1164 }
1165
1166 fn to_prompt(&self) -> String {
1167 let mut env = minijinja::Environment::new();
1169 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1170 panic!("Failed to parse template: {}", e)
1171 });
1172
1173 let tmpl = env.get_template("prompt").unwrap();
1174
1175 let mut context = std::collections::HashMap::new();
1176 #(#context_fields)*
1177
1178 tmpl.render(context).unwrap_or_else(|e| {
1179 format!("Failed to render prompt: {}", e)
1180 })
1181 }
1182
1183 fn prompt_schema() -> String {
1184 String::new() }
1186 }
1187 }
1188 } else {
1189 quote! {
1191 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1192 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1193 let mut parts = Vec::new();
1194
1195 #(#image_field_parts)*
1197
1198 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1200 format!("Failed to render prompt: {}", e)
1201 });
1202 if !text.is_empty() {
1203 parts.push(#crate_path::prompt::PromptPart::Text(text));
1204 }
1205
1206 parts
1207 }
1208
1209 fn to_prompt(&self) -> String {
1210 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1211 format!("Failed to render prompt: {}", e)
1212 })
1213 }
1214
1215 fn prompt_schema() -> String {
1216 String::new() }
1218 }
1219 }
1220 }
1221 } else {
1222 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1225 &fields.named
1226 } else {
1227 panic!(
1228 "Default prompt generation is only supported for structs with named fields."
1229 );
1230 };
1231
1232 let mut text_field_parts = Vec::new();
1234 let mut image_field_parts = Vec::new();
1235
1236 for f in fields.iter() {
1237 let field_name = f.ident.as_ref().unwrap();
1238 let attrs = parse_field_prompt_attrs(&f.attrs);
1239
1240 if attrs.skip {
1242 continue;
1243 }
1244
1245 if attrs.image {
1246 image_field_parts.push(quote! {
1248 parts.extend(self.#field_name.to_prompt_parts());
1249 });
1250 } else {
1251 let key = if let Some(rename) = attrs.rename {
1257 rename
1258 } else {
1259 let doc_comment = extract_doc_comments(&f.attrs);
1260 if !doc_comment.is_empty() {
1261 doc_comment
1262 } else {
1263 field_name.to_string()
1264 }
1265 };
1266
1267 let value_expr = if let Some(format_with) = attrs.format_with {
1269 let func_path: syn::Path =
1271 syn::parse_str(&format_with).unwrap_or_else(|_| {
1272 panic!("Invalid function path: {}", format_with)
1273 });
1274 quote! { #func_path(&self.#field_name) }
1275 } else {
1276 quote! { self.#field_name.to_prompt() }
1277 };
1278
1279 text_field_parts.push(quote! {
1280 text_parts.push(format!("{}: {}", #key, #value_expr));
1281 });
1282 }
1283 }
1284
1285 quote! {
1287 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1288 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1289 let mut parts = Vec::new();
1290
1291 #(#image_field_parts)*
1293
1294 let mut text_parts = Vec::new();
1296 #(#text_field_parts)*
1297
1298 if !text_parts.is_empty() {
1299 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1300 }
1301
1302 parts
1303 }
1304
1305 fn to_prompt(&self) -> String {
1306 let mut text_parts = Vec::new();
1307 #(#text_field_parts)*
1308 text_parts.join("\n")
1309 }
1310
1311 fn prompt_schema() -> String {
1312 String::new() }
1314 }
1315 }
1316 };
1317
1318 TokenStream::from(expanded)
1319 }
1320 Data::Union(_) => {
1321 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1322 }
1323 }
1324}
1325
1326#[derive(Debug, Clone)]
1328struct TargetInfo {
1329 name: String,
1330 template: Option<String>,
1331 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1332}
1333
/// Per-target options parsed from a field-level `#[prompt_for(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: omit the field. The bare `#[prompt_for(skip)]` form is recorded
    /// under the target name `"*"`.
    skip: bool,
    /// `rename = "..."`: alternative label for the field.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function path used to render the field.
    format_with: Option<String>,
    /// `image`: the field is flagged as an image part.
    image: bool,
    /// Set when the attribute names a specific target via `name = "..."`.
    include_only: bool,
}
1343
1344fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1346 let mut configs = Vec::new();
1347
1348 for attr in attrs {
1349 if attr.path().is_ident("prompt_for")
1350 && let Ok(meta_list) = attr.meta.require_list()
1351 {
1352 if meta_list.tokens.to_string() == "skip" {
1354 let config = FieldTargetConfig {
1356 skip: true,
1357 ..Default::default()
1358 };
1359 configs.push(("*".to_string(), config));
1360 } else if let Ok(metas) =
1361 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1362 {
1363 let mut target_name = None;
1364 let mut config = FieldTargetConfig::default();
1365
1366 for meta in metas {
1367 match meta {
1368 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1369 if let syn::Expr::Lit(syn::ExprLit {
1370 lit: syn::Lit::Str(lit_str),
1371 ..
1372 }) = nv.value
1373 {
1374 target_name = Some(lit_str.value());
1375 }
1376 }
1377 Meta::Path(path) if path.is_ident("skip") => {
1378 config.skip = true;
1379 }
1380 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1381 if let syn::Expr::Lit(syn::ExprLit {
1382 lit: syn::Lit::Str(lit_str),
1383 ..
1384 }) = nv.value
1385 {
1386 config.rename = Some(lit_str.value());
1387 }
1388 }
1389 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1390 if let syn::Expr::Lit(syn::ExprLit {
1391 lit: syn::Lit::Str(lit_str),
1392 ..
1393 }) = nv.value
1394 {
1395 config.format_with = Some(lit_str.value());
1396 }
1397 }
1398 Meta::Path(path) if path.is_ident("image") => {
1399 config.image = true;
1400 }
1401 _ => {}
1402 }
1403 }
1404
1405 if let Some(name) = target_name {
1406 config.include_only = true;
1407 configs.push((name, config));
1408 }
1409 }
1410 }
1411 }
1412
1413 configs
1414}
1415
1416fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1418 let mut targets = Vec::new();
1419
1420 for attr in attrs {
1421 if attr.path().is_ident("prompt_for")
1422 && let Ok(meta_list) = attr.meta.require_list()
1423 && let Ok(metas) =
1424 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1425 {
1426 let mut target_name = None;
1427 let mut template = None;
1428
1429 for meta in metas {
1430 match meta {
1431 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1432 if let syn::Expr::Lit(syn::ExprLit {
1433 lit: syn::Lit::Str(lit_str),
1434 ..
1435 }) = nv.value
1436 {
1437 target_name = Some(lit_str.value());
1438 }
1439 }
1440 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1441 if let syn::Expr::Lit(syn::ExprLit {
1442 lit: syn::Lit::Str(lit_str),
1443 ..
1444 }) = nv.value
1445 {
1446 template = Some(lit_str.value());
1447 }
1448 }
1449 _ => {}
1450 }
1451 }
1452
1453 if let Some(name) = target_name {
1454 targets.push(TargetInfo {
1455 name,
1456 template,
1457 field_configs: std::collections::HashMap::new(),
1458 });
1459 }
1460 }
1461 }
1462
1463 targets
1464}
1465
1466#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1467pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1468 let input = parse_macro_input!(input as DeriveInput);
1469
1470 let found_crate =
1471 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1472 let crate_path = match found_crate {
1473 FoundCrate::Itself => {
1474 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1476 quote!(::#ident)
1477 }
1478 FoundCrate::Name(name) => {
1479 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1480 quote!(::#ident)
1481 }
1482 };
1483
1484 let data_struct = match &input.data {
1486 Data::Struct(data) => data,
1487 _ => {
1488 return syn::Error::new(
1489 input.ident.span(),
1490 "`#[derive(ToPromptSet)]` is only supported for structs",
1491 )
1492 .to_compile_error()
1493 .into();
1494 }
1495 };
1496
1497 let fields = match &data_struct.fields {
1498 syn::Fields::Named(fields) => &fields.named,
1499 _ => {
1500 return syn::Error::new(
1501 input.ident.span(),
1502 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1503 )
1504 .to_compile_error()
1505 .into();
1506 }
1507 };
1508
1509 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1511
1512 for field in fields.iter() {
1514 let field_name = field.ident.as_ref().unwrap().to_string();
1515 let field_configs = parse_prompt_for_attrs(&field.attrs);
1516
1517 for (target_name, config) in field_configs {
1518 if target_name == "*" {
1519 for target in &mut targets {
1521 target
1522 .field_configs
1523 .entry(field_name.clone())
1524 .or_insert_with(FieldTargetConfig::default)
1525 .skip = config.skip;
1526 }
1527 } else {
1528 let target_exists = targets.iter().any(|t| t.name == target_name);
1530 if !target_exists {
1531 targets.push(TargetInfo {
1533 name: target_name.clone(),
1534 template: None,
1535 field_configs: std::collections::HashMap::new(),
1536 });
1537 }
1538
1539 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1540
1541 target.field_configs.insert(field_name.clone(), config);
1542 }
1543 }
1544 }
1545
1546 let mut match_arms = Vec::new();
1548
1549 for target in &targets {
1550 let target_name = &target.name;
1551
1552 if let Some(template_str) = &target.template {
1553 let mut image_parts = Vec::new();
1555
1556 for field in fields.iter() {
1557 let field_name = field.ident.as_ref().unwrap();
1558 let field_name_str = field_name.to_string();
1559
1560 if let Some(config) = target.field_configs.get(&field_name_str)
1561 && config.image
1562 {
1563 image_parts.push(quote! {
1564 parts.extend(self.#field_name.to_prompt_parts());
1565 });
1566 }
1567 }
1568
1569 match_arms.push(quote! {
1570 #target_name => {
1571 let mut parts = Vec::new();
1572
1573 #(#image_parts)*
1574
1575 let text = #crate_path::prompt::render_prompt(#template_str, self)
1576 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1577 target: #target_name.to_string(),
1578 source: e,
1579 })?;
1580
1581 if !text.is_empty() {
1582 parts.push(#crate_path::prompt::PromptPart::Text(text));
1583 }
1584
1585 Ok(parts)
1586 }
1587 });
1588 } else {
1589 let mut text_field_parts = Vec::new();
1591 let mut image_field_parts = Vec::new();
1592
1593 for field in fields.iter() {
1594 let field_name = field.ident.as_ref().unwrap();
1595 let field_name_str = field_name.to_string();
1596
1597 let config = target.field_configs.get(&field_name_str);
1599
1600 if let Some(cfg) = config
1602 && cfg.skip
1603 {
1604 continue;
1605 }
1606
1607 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1611 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1612 .iter()
1613 .any(|(name, _)| name != "*");
1614
1615 if has_any_target_specific_config && !is_explicitly_for_this_target {
1616 continue;
1617 }
1618
1619 if let Some(cfg) = config {
1620 if cfg.image {
1621 image_field_parts.push(quote! {
1622 parts.extend(self.#field_name.to_prompt_parts());
1623 });
1624 } else {
1625 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1626
1627 let value_expr = if let Some(format_with) = &cfg.format_with {
1628 match syn::parse_str::<syn::Path>(format_with) {
1630 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1631 Err(_) => {
1632 let error_msg = format!(
1634 "Invalid function path in format_with: '{}'",
1635 format_with
1636 );
1637 quote! {
1638 compile_error!(#error_msg);
1639 String::new()
1640 }
1641 }
1642 }
1643 } else {
1644 quote! { self.#field_name.to_prompt() }
1645 };
1646
1647 text_field_parts.push(quote! {
1648 text_parts.push(format!("{}: {}", #key, #value_expr));
1649 });
1650 }
1651 } else {
1652 text_field_parts.push(quote! {
1654 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1655 });
1656 }
1657 }
1658
1659 match_arms.push(quote! {
1660 #target_name => {
1661 let mut parts = Vec::new();
1662
1663 #(#image_field_parts)*
1664
1665 let mut text_parts = Vec::new();
1666 #(#text_field_parts)*
1667
1668 if !text_parts.is_empty() {
1669 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1670 }
1671
1672 Ok(parts)
1673 }
1674 });
1675 }
1676 }
1677
1678 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1680
1681 match_arms.push(quote! {
1683 _ => {
1684 let available = vec![#(#target_names.to_string()),*];
1685 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1686 target: target.to_string(),
1687 available,
1688 })
1689 }
1690 });
1691
1692 let struct_name = &input.ident;
1693 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1694
1695 let expanded = quote! {
1696 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1697 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1698 match target {
1699 #(#match_arms)*
1700 }
1701 }
1702 }
1703 };
1704
1705 TokenStream::from(expanded)
1706}
1707
1708struct TypeList {
1710 types: Punctuated<syn::Type, Token![,]>,
1711}
1712
1713impl Parse for TypeList {
1714 fn parse(input: ParseStream) -> syn::Result<Self> {
1715 Ok(TypeList {
1716 types: Punctuated::parse_terminated(input)?,
1717 })
1718 }
1719}
1720
1721#[proc_macro]
1745pub fn examples_section(input: TokenStream) -> TokenStream {
1746 let input = parse_macro_input!(input as TypeList);
1747
1748 let found_crate =
1749 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1750 let _crate_path = match found_crate {
1751 FoundCrate::Itself => quote!(crate),
1752 FoundCrate::Name(name) => {
1753 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1754 quote!(::#ident)
1755 }
1756 };
1757
1758 let mut type_sections = Vec::new();
1760
1761 for ty in input.types.iter() {
1762 let type_name_str = quote!(#ty).to_string();
1764
1765 type_sections.push(quote! {
1767 {
1768 let type_name = #type_name_str;
1769 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1770 format!("---\n#### `{}`\n{}", type_name, json_example)
1771 }
1772 });
1773 }
1774
1775 let expanded = quote! {
1777 {
1778 let mut sections = Vec::new();
1779 sections.push("---".to_string());
1780 sections.push("### Examples".to_string());
1781 sections.push("".to_string());
1782 sections.push("Here are examples of the data structures you should use.".to_string());
1783 sections.push("".to_string());
1784
1785 #(sections.push(#type_sections);)*
1786
1787 sections.push("---".to_string());
1788
1789 sections.join("\n")
1790 }
1791 };
1792
1793 TokenStream::from(expanded)
1794}
1795
1796fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1798 for attr in attrs {
1799 if attr.path().is_ident("prompt_for")
1800 && let Ok(meta_list) = attr.meta.require_list()
1801 && let Ok(metas) =
1802 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1803 {
1804 let mut target_type = None;
1805 let mut template = None;
1806
1807 for meta in metas {
1808 match meta {
1809 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1810 if let syn::Expr::Lit(syn::ExprLit {
1811 lit: syn::Lit::Str(lit_str),
1812 ..
1813 }) = nv.value
1814 {
1815 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1817 }
1818 }
1819 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1820 if let syn::Expr::Lit(syn::ExprLit {
1821 lit: syn::Lit::Str(lit_str),
1822 ..
1823 }) = nv.value
1824 {
1825 template = Some(lit_str.value());
1826 }
1827 }
1828 _ => {}
1829 }
1830 }
1831
1832 if let (Some(target), Some(tmpl)) = (target_type, template) {
1833 return (target, tmpl);
1834 }
1835 }
1836 }
1837
1838 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1839}
1840
1841#[proc_macro_attribute]
1875pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1876 let input = parse_macro_input!(item as DeriveInput);
1877
1878 let found_crate =
1879 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1880 let crate_path = match found_crate {
1881 FoundCrate::Itself => {
1882 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1884 quote!(::#ident)
1885 }
1886 FoundCrate::Name(name) => {
1887 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1888 quote!(::#ident)
1889 }
1890 };
1891
1892 let enum_data = match &input.data {
1894 Data::Enum(data) => data,
1895 _ => {
1896 return syn::Error::new(
1897 input.ident.span(),
1898 "`#[define_intent]` can only be applied to enums",
1899 )
1900 .to_compile_error()
1901 .into();
1902 }
1903 };
1904
1905 let mut prompt_template = None;
1907 let mut extractor_tag = None;
1908 let mut mode = None;
1909
1910 for attr in &input.attrs {
1911 if attr.path().is_ident("intent")
1912 && let Ok(metas) =
1913 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1914 {
1915 for meta in metas {
1916 match meta {
1917 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1918 if let syn::Expr::Lit(syn::ExprLit {
1919 lit: syn::Lit::Str(lit_str),
1920 ..
1921 }) = nv.value
1922 {
1923 prompt_template = Some(lit_str.value());
1924 }
1925 }
1926 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1927 if let syn::Expr::Lit(syn::ExprLit {
1928 lit: syn::Lit::Str(lit_str),
1929 ..
1930 }) = nv.value
1931 {
1932 extractor_tag = Some(lit_str.value());
1933 }
1934 }
1935 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1936 if let syn::Expr::Lit(syn::ExprLit {
1937 lit: syn::Lit::Str(lit_str),
1938 ..
1939 }) = nv.value
1940 {
1941 mode = Some(lit_str.value());
1942 }
1943 }
1944 _ => {}
1945 }
1946 }
1947 }
1948 }
1949
1950 let mode = mode.unwrap_or_else(|| "single".to_string());
1952
1953 if mode != "single" && mode != "multi_tag" {
1955 return syn::Error::new(
1956 input.ident.span(),
1957 "`mode` must be either \"single\" or \"multi_tag\"",
1958 )
1959 .to_compile_error()
1960 .into();
1961 }
1962
1963 let prompt_template = match prompt_template {
1965 Some(p) => p,
1966 None => {
1967 return syn::Error::new(
1968 input.ident.span(),
1969 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1970 )
1971 .to_compile_error()
1972 .into();
1973 }
1974 };
1975
1976 if mode == "multi_tag" {
1978 let enum_name = &input.ident;
1979 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1980 return generate_multi_tag_output(
1981 &input,
1982 enum_name,
1983 enum_data,
1984 prompt_template,
1985 actions_doc,
1986 );
1987 }
1988
1989 let extractor_tag = match extractor_tag {
1991 Some(t) => t,
1992 None => {
1993 return syn::Error::new(
1994 input.ident.span(),
1995 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1996 )
1997 .to_compile_error()
1998 .into();
1999 }
2000 };
2001
2002 let enum_name = &input.ident;
2004 let enum_docs = extract_doc_comments(&input.attrs);
2005
2006 let mut intents_doc_lines = Vec::new();
2007
2008 if !enum_docs.is_empty() {
2010 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2011 } else {
2012 intents_doc_lines.push(format!("{}:", enum_name));
2013 }
2014 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2016
2017 for variant in &enum_data.variants {
2019 let variant_name = &variant.ident;
2020 let variant_docs = extract_doc_comments(&variant.attrs);
2021
2022 if !variant_docs.is_empty() {
2023 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2024 } else {
2025 intents_doc_lines.push(format!("- {}", variant_name));
2026 }
2027 }
2028
2029 let intents_doc_str = intents_doc_lines.join("\n");
2030
2031 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2033 let user_variables: Vec<String> = placeholders
2034 .iter()
2035 .filter_map(|(name, _)| {
2036 if name != "intents_doc" {
2037 Some(name.clone())
2038 } else {
2039 None
2040 }
2041 })
2042 .collect();
2043
2044 let enum_name_str = enum_name.to_string();
2046 let snake_case_name = to_snake_case(&enum_name_str);
2047 let function_name = syn::Ident::new(
2048 &format!("build_{}_prompt", snake_case_name),
2049 proc_macro2::Span::call_site(),
2050 );
2051
2052 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2054 .iter()
2055 .map(|var| {
2056 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2057 quote! { #ident: &str }
2058 })
2059 .collect();
2060
2061 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2063 .iter()
2064 .map(|var| {
2065 let var_str = var.clone();
2066 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2067 quote! {
2068 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2069 }
2070 })
2071 .collect();
2072
2073 let converted_template = prompt_template.clone();
2075
2076 let extractor_name = syn::Ident::new(
2078 &format!("{}Extractor", enum_name),
2079 proc_macro2::Span::call_site(),
2080 );
2081
2082 let filtered_attrs: Vec<_> = input
2084 .attrs
2085 .iter()
2086 .filter(|attr| !attr.path().is_ident("intent"))
2087 .collect();
2088
2089 let vis = &input.vis;
2091 let generics = &input.generics;
2092 let variants = &enum_data.variants;
2093 let enum_output = quote! {
2094 #(#filtered_attrs)*
2095 #vis enum #enum_name #generics {
2096 #variants
2097 }
2098 };
2099
2100 let expanded = quote! {
2102 #enum_output
2104
2105 pub fn #function_name(#(#function_params),*) -> String {
2107 let mut env = minijinja::Environment::new();
2108 env.add_template("prompt", #converted_template)
2109 .expect("Failed to parse intent prompt template");
2110
2111 let tmpl = env.get_template("prompt").unwrap();
2112
2113 let mut __template_context = std::collections::HashMap::new();
2114
2115 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2117
2118 #(#context_insertions)*
2120
2121 tmpl.render(&__template_context)
2122 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2123 }
2124
2125 pub struct #extractor_name;
2127
2128 impl #extractor_name {
2129 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2130 }
2131
2132 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2133 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2134 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2136 }
2137 }
2138 };
2139
2140 TokenStream::from(expanded)
2141}
2142
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore goes before an uppercase character only when it is not the
/// first character and the previous character was not uppercase. Acronym runs
/// therefore collapse (`HTTPServer` becomes `httpserver`). Only the first
/// character of each lowercase mapping is kept.
fn to_snake_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 4);
    let mut after_upper = false;

    for (idx, c) in s.chars().enumerate() {
        let upper = c.is_uppercase();
        if upper {
            if idx != 0 && !after_upper {
                out.push('_');
            }
            // `to_lowercase` always yields at least one char; keep the first.
            out.extend(c.to_lowercase().next());
        } else {
            out.push(c);
        }
        after_upper = upper;
    }

    out
}
2163
/// Variant-level options from `#[action(tag = "...")]`.
#[derive(Default, Debug)]
struct ActionAttrs {
    /// XML tag name the variant is written as; `None` when the variant has no tag.
    tag: Option<String>,
}
2169
2170fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2171 let mut result = ActionAttrs::default();
2172
2173 for attr in attrs {
2174 if attr.path().is_ident("action")
2175 && let Ok(meta_list) = attr.meta.require_list()
2176 && let Ok(metas) =
2177 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2178 {
2179 for meta in metas {
2180 if let Meta::NameValue(nv) = meta
2181 && nv.path.is_ident("tag")
2182 && let syn::Expr::Lit(syn::ExprLit {
2183 lit: syn::Lit::Str(lit_str),
2184 ..
2185 }) = nv.value
2186 {
2187 result.tag = Some(lit_str.value());
2188 }
2189 }
2190 }
2191 }
2192
2193 result
2194}
2195
/// Field-level role inside an action variant, from `#[action(attribute)]` or
/// `#[action(inner_text)]`.
#[derive(Default, Debug)]
struct FieldActionAttrs {
    /// The field is read from an XML attribute of the same name.
    is_attribute: bool,
    /// The field receives the element's text content.
    is_inner_text: bool,
}
2202
2203fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2204 let mut result = FieldActionAttrs::default();
2205
2206 for attr in attrs {
2207 if attr.path().is_ident("action")
2208 && let Ok(meta_list) = attr.meta.require_list()
2209 {
2210 let tokens_str = meta_list.tokens.to_string();
2211 if tokens_str == "attribute" {
2212 result.is_attribute = true;
2213 } else if tokens_str == "inner_text" {
2214 result.is_inner_text = true;
2215 }
2216 }
2217 }
2218
2219 result
2220}
2221
2222fn generate_multi_tag_actions_doc(
2224 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2225) -> String {
2226 let mut doc_lines = Vec::new();
2227
2228 for variant in variants {
2229 let action_attrs = parse_action_attrs(&variant.attrs);
2230
2231 if let Some(tag) = action_attrs.tag {
2232 let variant_docs = extract_doc_comments(&variant.attrs);
2233
2234 match &variant.fields {
2235 syn::Fields::Unit => {
2236 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2238 }
2239 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2240 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2242 }
2243 syn::Fields::Named(fields) => {
2244 let mut attrs_str = Vec::new();
2246 let mut has_inner_text = false;
2247
2248 for field in &fields.named {
2249 let field_name = field.ident.as_ref().unwrap();
2250 let field_attrs = parse_field_action_attrs(&field.attrs);
2251
2252 if field_attrs.is_attribute {
2253 attrs_str.push(format!("{}=\"...\"", field_name));
2254 } else if field_attrs.is_inner_text {
2255 has_inner_text = true;
2256 }
2257 }
2258
2259 let attrs_part = if !attrs_str.is_empty() {
2260 format!(" {}", attrs_str.join(" "))
2261 } else {
2262 String::new()
2263 };
2264
2265 if has_inner_text {
2266 doc_lines.push(format!(
2267 "- `<{}{}>...</{}>`: {}",
2268 tag, attrs_part, tag, variant_docs
2269 ));
2270 } else if !attrs_str.is_empty() {
2271 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2272 } else {
2273 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2274 }
2275
2276 for field in &fields.named {
2278 let field_name = field.ident.as_ref().unwrap();
2279 let field_attrs = parse_field_action_attrs(&field.attrs);
2280 let field_docs = extract_doc_comments(&field.attrs);
2281
2282 if field_attrs.is_attribute {
2283 doc_lines
2284 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2285 } else if field_attrs.is_inner_text {
2286 doc_lines
2287 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2288 }
2289 }
2290 }
2291 _ => {
2292 }
2294 }
2295 }
2296 }
2297
2298 doc_lines.join("\n")
2299}
2300
2301fn generate_tags_regex(
2303 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2304) -> String {
2305 let mut tag_names = Vec::new();
2306
2307 for variant in variants {
2308 let action_attrs = parse_action_attrs(&variant.attrs);
2309 if let Some(tag) = action_attrs.tag {
2310 tag_names.push(tag);
2311 }
2312 }
2313
2314 if tag_names.is_empty() {
2315 return String::new();
2316 }
2317
2318 let tags_pattern = tag_names.join("|");
2319 format!(
2322 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2323 tags_pattern, tags_pattern, tags_pattern
2324 )
2325}
2326
2327fn generate_multi_tag_output(
2329 input: &DeriveInput,
2330 enum_name: &syn::Ident,
2331 enum_data: &syn::DataEnum,
2332 prompt_template: String,
2333 actions_doc: String,
2334) -> TokenStream {
2335 let found_crate =
2336 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2337 let crate_path = match found_crate {
2338 FoundCrate::Itself => {
2339 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2341 quote!(::#ident)
2342 }
2343 FoundCrate::Name(name) => {
2344 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2345 quote!(::#ident)
2346 }
2347 };
2348
2349 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2351 let user_variables: Vec<String> = placeholders
2352 .iter()
2353 .filter_map(|(name, _)| {
2354 if name != "actions_doc" {
2355 Some(name.clone())
2356 } else {
2357 None
2358 }
2359 })
2360 .collect();
2361
2362 let enum_name_str = enum_name.to_string();
2364 let snake_case_name = to_snake_case(&enum_name_str);
2365 let function_name = syn::Ident::new(
2366 &format!("build_{}_prompt", snake_case_name),
2367 proc_macro2::Span::call_site(),
2368 );
2369
2370 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2372 .iter()
2373 .map(|var| {
2374 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2375 quote! { #ident: &str }
2376 })
2377 .collect();
2378
2379 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2381 .iter()
2382 .map(|var| {
2383 let var_str = var.clone();
2384 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2385 quote! {
2386 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2387 }
2388 })
2389 .collect();
2390
2391 let extractor_name = syn::Ident::new(
2393 &format!("{}Extractor", enum_name),
2394 proc_macro2::Span::call_site(),
2395 );
2396
2397 let filtered_attrs: Vec<_> = input
2399 .attrs
2400 .iter()
2401 .filter(|attr| !attr.path().is_ident("intent"))
2402 .collect();
2403
2404 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2406 .variants
2407 .iter()
2408 .map(|variant| {
2409 let variant_name = &variant.ident;
2410 let variant_attrs: Vec<_> = variant
2411 .attrs
2412 .iter()
2413 .filter(|attr| !attr.path().is_ident("action"))
2414 .collect();
2415 let fields = &variant.fields;
2416
2417 let filtered_fields = match fields {
2419 syn::Fields::Named(named_fields) => {
2420 let filtered: Vec<_> = named_fields
2421 .named
2422 .iter()
2423 .map(|field| {
2424 let field_name = &field.ident;
2425 let field_type = &field.ty;
2426 let field_vis = &field.vis;
2427 let filtered_attrs: Vec<_> = field
2428 .attrs
2429 .iter()
2430 .filter(|attr| !attr.path().is_ident("action"))
2431 .collect();
2432 quote! {
2433 #(#filtered_attrs)*
2434 #field_vis #field_name: #field_type
2435 }
2436 })
2437 .collect();
2438 quote! { { #(#filtered,)* } }
2439 }
2440 syn::Fields::Unnamed(unnamed_fields) => {
2441 let types: Vec<_> = unnamed_fields
2442 .unnamed
2443 .iter()
2444 .map(|field| {
2445 let field_type = &field.ty;
2446 quote! { #field_type }
2447 })
2448 .collect();
2449 quote! { (#(#types),*) }
2450 }
2451 syn::Fields::Unit => quote! {},
2452 };
2453
2454 quote! {
2455 #(#variant_attrs)*
2456 #variant_name #filtered_fields
2457 }
2458 })
2459 .collect();
2460
2461 let vis = &input.vis;
2462 let generics = &input.generics;
2463
2464 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2466
2467 let tags_regex = generate_tags_regex(&enum_data.variants);
2469
2470 let expanded = quote! {
2471 #(#filtered_attrs)*
2473 #vis enum #enum_name #generics {
2474 #(#filtered_variants),*
2475 }
2476
2477 pub fn #function_name(#(#function_params),*) -> String {
2479 let mut env = minijinja::Environment::new();
2480 env.add_template("prompt", #prompt_template)
2481 .expect("Failed to parse intent prompt template");
2482
2483 let tmpl = env.get_template("prompt").unwrap();
2484
2485 let mut __template_context = std::collections::HashMap::new();
2486
2487 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2489
2490 #(#context_insertions)*
2492
2493 tmpl.render(&__template_context)
2494 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2495 }
2496
2497 pub struct #extractor_name;
2499
2500 impl #extractor_name {
2501 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2502 use ::quick_xml::events::Event;
2503 use ::quick_xml::Reader;
2504
2505 let mut actions = Vec::new();
2506 let mut reader = Reader::from_str(text);
2507 reader.config_mut().trim_text(true);
2508
2509 let mut buf = Vec::new();
2510
2511 loop {
2512 match reader.read_event_into(&mut buf) {
2513 Ok(Event::Start(e)) => {
2514 let owned_e = e.into_owned();
2515 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2516 let is_empty = false;
2517
2518 #parsing_arms
2519 }
2520 Ok(Event::Empty(e)) => {
2521 let owned_e = e.into_owned();
2522 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2523 let is_empty = true;
2524
2525 #parsing_arms
2526 }
2527 Ok(Event::Eof) => break,
2528 Err(_) => {
2529 break;
2531 }
2532 _ => {}
2533 }
2534 buf.clear();
2535 }
2536
2537 actions.into_iter().next()
2538 }
2539
2540 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2541 use ::quick_xml::events::Event;
2542 use ::quick_xml::Reader;
2543
2544 let mut actions = Vec::new();
2545 let mut reader = Reader::from_str(text);
2546 reader.config_mut().trim_text(true);
2547
2548 let mut buf = Vec::new();
2549
2550 loop {
2551 match reader.read_event_into(&mut buf) {
2552 Ok(Event::Start(e)) => {
2553 let owned_e = e.into_owned();
2554 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2555 let is_empty = false;
2556
2557 #parsing_arms
2558 }
2559 Ok(Event::Empty(e)) => {
2560 let owned_e = e.into_owned();
2561 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2562 let is_empty = true;
2563
2564 #parsing_arms
2565 }
2566 Ok(Event::Eof) => break,
2567 Err(_) => {
2568 break;
2570 }
2571 _ => {}
2572 }
2573 buf.clear();
2574 }
2575
2576 Ok(actions)
2577 }
2578
2579 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2580 where
2581 F: FnMut(#enum_name) -> String,
2582 {
2583 use ::regex::Regex;
2584
2585 let regex_pattern = #tags_regex;
2586 if regex_pattern.is_empty() {
2587 return text.to_string();
2588 }
2589
2590 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2591 panic!("Failed to compile regex for action tags: {}", e);
2592 });
2593
2594 re.replace_all(text, |caps: &::regex::Captures| {
2595 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2596
2597 if let Some(action) = self.parse_single_action(matched) {
2599 transformer(action)
2600 } else {
2601 matched.to_string()
2603 }
2604 }).to_string()
2605 }
2606
2607 pub fn strip_actions(&self, text: &str) -> String {
2608 self.transform_actions(text, |_| String::new())
2609 }
2610 }
2611 };
2612
2613 TokenStream::from(expanded)
2614}
2615
2616fn generate_parsing_arms(
2618 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2619 enum_name: &syn::Ident,
2620) -> proc_macro2::TokenStream {
2621 let mut arms = Vec::new();
2622
2623 for variant in variants {
2624 let variant_name = &variant.ident;
2625 let action_attrs = parse_action_attrs(&variant.attrs);
2626
2627 if let Some(tag) = action_attrs.tag {
2628 match &variant.fields {
2629 syn::Fields::Unit => {
2630 arms.push(quote! {
2632 if &tag_name == #tag {
2633 actions.push(#enum_name::#variant_name);
2634 }
2635 });
2636 }
2637 syn::Fields::Unnamed(_fields) => {
2638 arms.push(quote! {
2640 if &tag_name == #tag && !is_empty {
2641 match reader.read_text(owned_e.name()) {
2643 Ok(text) => {
2644 actions.push(#enum_name::#variant_name(text.to_string()));
2645 }
2646 Err(_) => {
2647 actions.push(#enum_name::#variant_name(String::new()));
2649 }
2650 }
2651 }
2652 });
2653 }
2654 syn::Fields::Named(fields) => {
2655 let mut field_names = Vec::new();
2657 let mut has_inner_text_field = None;
2658
2659 for field in &fields.named {
2660 let field_name = field.ident.as_ref().unwrap();
2661 let field_attrs = parse_field_action_attrs(&field.attrs);
2662
2663 if field_attrs.is_attribute {
2664 field_names.push(field_name.clone());
2665 } else if field_attrs.is_inner_text {
2666 has_inner_text_field = Some(field_name.clone());
2667 }
2668 }
2669
2670 if let Some(inner_text_field) = has_inner_text_field {
2671 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2674 quote! {
2675 let mut #field_name = String::new();
2676 for attr in owned_e.attributes() {
2677 if let Ok(attr) = attr {
2678 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2679 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2680 break;
2681 }
2682 }
2683 }
2684 }
2685 }).collect();
2686
2687 arms.push(quote! {
2688 if &tag_name == #tag {
2689 #(#attr_extractions)*
2690
2691 if is_empty {
2693 let #inner_text_field = String::new();
2694 actions.push(#enum_name::#variant_name {
2695 #(#field_names,)*
2696 #inner_text_field,
2697 });
2698 } else {
2699 match reader.read_text(owned_e.name()) {
2701 Ok(text) => {
2702 let #inner_text_field = text.to_string();
2703 actions.push(#enum_name::#variant_name {
2704 #(#field_names,)*
2705 #inner_text_field,
2706 });
2707 }
2708 Err(_) => {
2709 let #inner_text_field = String::new();
2711 actions.push(#enum_name::#variant_name {
2712 #(#field_names,)*
2713 #inner_text_field,
2714 });
2715 }
2716 }
2717 }
2718 }
2719 });
2720 } else {
2721 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2723 quote! {
2724 let mut #field_name = String::new();
2725 for attr in owned_e.attributes() {
2726 if let Ok(attr) = attr {
2727 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2728 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2729 break;
2730 }
2731 }
2732 }
2733 }
2734 }).collect();
2735
2736 arms.push(quote! {
2737 if &tag_name == #tag {
2738 #(#attr_extractions)*
2739 actions.push(#enum_name::#variant_name {
2740 #(#field_names),*
2741 });
2742 }
2743 });
2744 }
2745 }
2746 }
2747 }
2748 }
2749
2750 quote! {
2751 #(#arms)*
2752 }
2753}
2754
2755#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2757pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2758 let input = parse_macro_input!(input as DeriveInput);
2759
2760 let found_crate =
2761 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2762 let crate_path = match found_crate {
2763 FoundCrate::Itself => {
2764 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2766 quote!(::#ident)
2767 }
2768 FoundCrate::Name(name) => {
2769 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2770 quote!(::#ident)
2771 }
2772 };
2773
2774 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2776
2777 let struct_name = &input.ident;
2778 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2779
2780 let placeholders = parse_template_placeholders_with_mode(&template);
2782
2783 let mut converted_template = template.clone();
2785 let mut context_fields = Vec::new();
2786
2787 let fields = match &input.data {
2789 Data::Struct(data_struct) => match &data_struct.fields {
2790 syn::Fields::Named(fields) => &fields.named,
2791 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2792 },
2793 _ => panic!("ToPromptFor is only supported for structs"),
2794 };
2795
2796 let has_mode_support = input.attrs.iter().any(|attr| {
2798 if attr.path().is_ident("prompt")
2799 && let Ok(metas) =
2800 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2801 {
2802 for meta in metas {
2803 if let Meta::NameValue(nv) = meta
2804 && nv.path.is_ident("mode")
2805 {
2806 return true;
2807 }
2808 }
2809 }
2810 false
2811 });
2812
2813 for (placeholder_name, mode_opt) in &placeholders {
2815 if placeholder_name == "self" {
2816 if let Some(specific_mode) = mode_opt {
2817 let unique_key = format!("self__{}", specific_mode);
2819
2820 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2822 let replacement = format!("{{{{ {} }}}}", unique_key);
2823 converted_template = converted_template.replace(&pattern, &replacement);
2824
2825 context_fields.push(quote! {
2827 context.insert(
2828 #unique_key.to_string(),
2829 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2830 );
2831 });
2832 } else {
2833 if has_mode_support {
2836 context_fields.push(quote! {
2838 context.insert(
2839 "self".to_string(),
2840 minijinja::Value::from(self.to_prompt_with_mode(mode))
2841 );
2842 });
2843 } else {
2844 context_fields.push(quote! {
2846 context.insert(
2847 "self".to_string(),
2848 minijinja::Value::from(self.to_prompt())
2849 );
2850 });
2851 }
2852 }
2853 } else {
2854 let field_exists = fields.iter().any(|f| {
2857 f.ident
2858 .as_ref()
2859 .is_some_and(|ident| ident == placeholder_name)
2860 });
2861
2862 if field_exists {
2863 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2864
2865 context_fields.push(quote! {
2869 context.insert(
2870 #placeholder_name.to_string(),
2871 minijinja::Value::from_serialize(&self.#field_ident)
2872 );
2873 });
2874 }
2875 }
2877 }
2878
2879 let expanded = quote! {
2880 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2881 where
2882 #target_type: serde::Serialize,
2883 {
2884 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2885 let mut env = minijinja::Environment::new();
2887 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2888 panic!("Failed to parse template: {}", e)
2889 });
2890
2891 let tmpl = env.get_template("prompt").unwrap();
2892
2893 let mut context = std::collections::HashMap::new();
2895 context.insert(
2897 "self".to_string(),
2898 minijinja::Value::from_serialize(self)
2899 );
2900 context.insert(
2902 "target".to_string(),
2903 minijinja::Value::from_serialize(target)
2904 );
2905 #(#context_fields)*
2906
2907 tmpl.render(context).unwrap_or_else(|e| {
2909 format!("Failed to render prompt: {}", e)
2910 })
2911 }
2912 }
2913 };
2914
2915 TokenStream::from(expanded)
2916}
2917
2918struct AgentAttrs {
2924 expertise: Option<String>,
2925 output: Option<syn::Type>,
2926 backend: Option<String>,
2927 model: Option<String>,
2928 inner: Option<String>,
2929 default_inner: Option<String>,
2930 max_retries: Option<u32>,
2931 profile: Option<String>,
2932}
2933
2934impl Parse for AgentAttrs {
2935 fn parse(input: ParseStream) -> syn::Result<Self> {
2936 let mut expertise = None;
2937 let mut output = None;
2938 let mut backend = None;
2939 let mut model = None;
2940 let mut inner = None;
2941 let mut default_inner = None;
2942 let mut max_retries = None;
2943 let mut profile = None;
2944
2945 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2946
2947 for meta in pairs {
2948 match meta {
2949 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2950 if let syn::Expr::Lit(syn::ExprLit {
2951 lit: syn::Lit::Str(lit_str),
2952 ..
2953 }) = &nv.value
2954 {
2955 expertise = Some(lit_str.value());
2956 }
2957 }
2958 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2959 if let syn::Expr::Lit(syn::ExprLit {
2960 lit: syn::Lit::Str(lit_str),
2961 ..
2962 }) = &nv.value
2963 {
2964 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2965 output = Some(ty);
2966 }
2967 }
2968 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2969 if let syn::Expr::Lit(syn::ExprLit {
2970 lit: syn::Lit::Str(lit_str),
2971 ..
2972 }) = &nv.value
2973 {
2974 backend = Some(lit_str.value());
2975 }
2976 }
2977 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2978 if let syn::Expr::Lit(syn::ExprLit {
2979 lit: syn::Lit::Str(lit_str),
2980 ..
2981 }) = &nv.value
2982 {
2983 model = Some(lit_str.value());
2984 }
2985 }
2986 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
2987 if let syn::Expr::Lit(syn::ExprLit {
2988 lit: syn::Lit::Str(lit_str),
2989 ..
2990 }) = &nv.value
2991 {
2992 inner = Some(lit_str.value());
2993 }
2994 }
2995 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
2996 if let syn::Expr::Lit(syn::ExprLit {
2997 lit: syn::Lit::Str(lit_str),
2998 ..
2999 }) = &nv.value
3000 {
3001 default_inner = Some(lit_str.value());
3002 }
3003 }
3004 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3005 if let syn::Expr::Lit(syn::ExprLit {
3006 lit: syn::Lit::Int(lit_int),
3007 ..
3008 }) = &nv.value
3009 {
3010 max_retries = Some(lit_int.base10_parse()?);
3011 }
3012 }
3013 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3014 if let syn::Expr::Lit(syn::ExprLit {
3015 lit: syn::Lit::Str(lit_str),
3016 ..
3017 }) = &nv.value
3018 {
3019 profile = Some(lit_str.value());
3020 }
3021 }
3022 _ => {}
3023 }
3024 }
3025
3026 Ok(AgentAttrs {
3027 expertise,
3028 output,
3029 backend,
3030 model,
3031 inner,
3032 default_inner,
3033 max_retries,
3034 profile,
3035 })
3036 }
3037}
3038
3039fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3041 for attr in attrs {
3042 if attr.path().is_ident("agent") {
3043 return attr.parse_args::<AgentAttrs>();
3044 }
3045 }
3046
3047 Ok(AgentAttrs {
3048 expertise: None,
3049 output: None,
3050 backend: None,
3051 model: None,
3052 inner: None,
3053 default_inner: None,
3054 max_retries: None,
3055 profile: None,
3056 })
3057}
3058
3059fn generate_backend_constructors(
3061 struct_name: &syn::Ident,
3062 backend: &str,
3063 _model: Option<&str>,
3064 _profile: Option<&str>,
3065 crate_path: &proc_macro2::TokenStream,
3066) -> proc_macro2::TokenStream {
3067 match backend {
3068 "claude" => {
3069 quote! {
3070 impl #struct_name {
3071 pub fn with_claude() -> Self {
3073 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3074 }
3075
3076 pub fn with_claude_model(model: &str) -> Self {
3078 Self::new(
3079 #crate_path::agent::impls::ClaudeCodeAgent::new()
3080 .with_model_str(model)
3081 )
3082 }
3083 }
3084 }
3085 }
3086 "gemini" => {
3087 quote! {
3088 impl #struct_name {
3089 pub fn with_gemini() -> Self {
3091 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3092 }
3093
3094 pub fn with_gemini_model(model: &str) -> Self {
3096 Self::new(
3097 #crate_path::agent::impls::GeminiAgent::new()
3098 .with_model_str(model)
3099 )
3100 }
3101 }
3102 }
3103 }
3104 _ => quote! {},
3105 }
3106}
3107
3108fn generate_default_impl(
3110 struct_name: &syn::Ident,
3111 backend: &str,
3112 model: Option<&str>,
3113 profile: Option<&str>,
3114 crate_path: &proc_macro2::TokenStream,
3115) -> proc_macro2::TokenStream {
3116 let profile_expr = if let Some(profile_str) = profile {
3118 match profile_str.to_lowercase().as_str() {
3119 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3120 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3121 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3122 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3124 } else {
3125 quote! { #crate_path::agent::ExecutionProfile::default() }
3126 };
3127
3128 let agent_init = match backend {
3129 "gemini" => {
3130 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3131
3132 if let Some(model_str) = model {
3133 builder = quote! { #builder.with_model_str(#model_str) };
3134 }
3135
3136 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3137 builder
3138 }
3139 _ => {
3140 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3142
3143 if let Some(model_str) = model {
3144 builder = quote! { #builder.with_model_str(#model_str) };
3145 }
3146
3147 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3148 builder
3149 }
3150 };
3151
3152 quote! {
3153 impl Default for #struct_name {
3154 fn default() -> Self {
3155 Self::new(#agent_init)
3156 }
3157 }
3158 }
3159}
3160
3161#[proc_macro_derive(Agent, attributes(agent))]
3170pub fn derive_agent(input: TokenStream) -> TokenStream {
3171 let input = parse_macro_input!(input as DeriveInput);
3172 let struct_name = &input.ident;
3173
3174 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3176 Ok(attrs) => attrs,
3177 Err(e) => return e.to_compile_error().into(),
3178 };
3179
3180 let expertise = agent_attrs
3181 .expertise
3182 .unwrap_or_else(|| String::from("general AI assistant"));
3183 let output_type = agent_attrs
3184 .output
3185 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3186 let backend = agent_attrs
3187 .backend
3188 .unwrap_or_else(|| String::from("claude"));
3189 let model = agent_attrs.model;
3190 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3195 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3196 let crate_path = match found_crate {
3197 FoundCrate::Itself => {
3198 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3200 quote!(::#ident)
3201 }
3202 FoundCrate::Name(name) => {
3203 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3204 quote!(::#ident)
3205 }
3206 };
3207
3208 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3209
3210 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3212 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3213
3214 let enhanced_expertise = if is_string_output {
3216 quote! { #expertise }
3218 } else {
3219 let type_name = quote!(#output_type).to_string();
3221 quote! {
3222 {
3223 use std::sync::OnceLock;
3224 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3225
3226 EXPERTISE_CACHE.get_or_init(|| {
3227 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3229
3230 if schema.is_empty() {
3231 format!(
3233 concat!(
3234 #expertise,
3235 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3236 "Do not include any text outside the JSON object."
3237 ),
3238 #type_name
3239 )
3240 } else {
3241 format!(
3243 concat!(
3244 #expertise,
3245 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3246 ),
3247 schema
3248 )
3249 }
3250 }).as_str()
3251 }
3252 }
3253 };
3254
3255 let agent_init = match backend.as_str() {
3257 "gemini" => {
3258 if let Some(model_str) = model {
3259 quote! {
3260 use #crate_path::agent::impls::GeminiAgent;
3261 let agent = GeminiAgent::new().with_model_str(#model_str);
3262 }
3263 } else {
3264 quote! {
3265 use #crate_path::agent::impls::GeminiAgent;
3266 let agent = GeminiAgent::new();
3267 }
3268 }
3269 }
3270 "claude" => {
3271 if let Some(model_str) = model {
3272 quote! {
3273 use #crate_path::agent::impls::ClaudeCodeAgent;
3274 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3275 }
3276 } else {
3277 quote! {
3278 use #crate_path::agent::impls::ClaudeCodeAgent;
3279 let agent = ClaudeCodeAgent::new();
3280 }
3281 }
3282 }
3283 _ => {
3284 if let Some(model_str) = model {
3286 quote! {
3287 use #crate_path::agent::impls::ClaudeCodeAgent;
3288 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3289 }
3290 } else {
3291 quote! {
3292 use #crate_path::agent::impls::ClaudeCodeAgent;
3293 let agent = ClaudeCodeAgent::new();
3294 }
3295 }
3296 }
3297 };
3298
3299 let expanded = quote! {
3300 #[async_trait::async_trait]
3301 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3302 type Output = #output_type;
3303
3304 fn expertise(&self) -> &str {
3305 #enhanced_expertise
3306 }
3307
3308 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3309 #agent_init
3311
3312 let max_retries: u32 = #max_retries;
3314 let mut attempts = 0u32;
3315
3316 loop {
3317 attempts += 1;
3318
3319 let result = async {
3321 let response = agent.execute(intent.clone()).await?;
3322
3323 let json_str = #crate_path::extract_json(&response)
3325 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3326
3327 serde_json::from_str::<Self::Output>(&json_str)
3329 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3330 }.await;
3331
3332 match result {
3333 Ok(output) => return Ok(output),
3334 Err(e) if e.is_retryable() && attempts < max_retries => {
3335 log::warn!(
3337 "Agent execution failed (attempt {}/{}): {}. Retrying...",
3338 attempts,
3339 max_retries,
3340 e
3341 );
3342
3343 let delay_ms = 100 * attempts as u64;
3345 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
3346
3347 continue;
3349 }
3350 Err(e) => {
3351 if attempts > 1 {
3352 log::error!(
3353 "Agent execution failed after {} attempts: {}",
3354 attempts,
3355 e
3356 );
3357 }
3358 return Err(e);
3359 }
3360 }
3361 }
3362 }
3363
3364 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3365 #agent_init
3367 agent.is_available().await
3368 }
3369 }
3370 };
3371
3372 TokenStream::from(expanded)
3373}
3374
3375#[proc_macro_attribute]
3390pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3391 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3393 Ok(attrs) => attrs,
3394 Err(e) => return e.to_compile_error().into(),
3395 };
3396
3397 let input = parse_macro_input!(item as DeriveInput);
3399 let struct_name = &input.ident;
3400 let vis = &input.vis;
3401
3402 let expertise = agent_attrs
3403 .expertise
3404 .unwrap_or_else(|| String::from("general AI assistant"));
3405 let output_type = agent_attrs
3406 .output
3407 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3408 let backend = agent_attrs
3409 .backend
3410 .unwrap_or_else(|| String::from("claude"));
3411 let model = agent_attrs.model;
3412 let profile = agent_attrs.profile;
3413
3414 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3416 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3417
3418 let found_crate =
3420 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3421 let crate_path = match found_crate {
3422 FoundCrate::Itself => {
3423 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3424 quote!(::#ident)
3425 }
3426 FoundCrate::Name(name) => {
3427 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3428 quote!(::#ident)
3429 }
3430 };
3431
3432 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3434 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3435
3436 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3438 let type_path: syn::Type =
3440 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3441 quote! { #type_path }
3442 } else {
3443 match backend.as_str() {
3445 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3446 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3447 }
3448 };
3449
3450 let struct_def = quote! {
3452 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3453 inner: #inner_generic_ident,
3454 }
3455 };
3456
3457 let constructors = quote! {
3459 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3460 pub fn new(inner: #inner_generic_ident) -> Self {
3462 Self { inner }
3463 }
3464 }
3465 };
3466
3467 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3469 let default_impl = quote! {
3471 impl Default for #struct_name {
3472 fn default() -> Self {
3473 Self {
3474 inner: <#default_agent_type as Default>::default(),
3475 }
3476 }
3477 }
3478 };
3479 (quote! {}, default_impl)
3480 } else {
3481 let backend_constructors = generate_backend_constructors(
3483 struct_name,
3484 &backend,
3485 model.as_deref(),
3486 profile.as_deref(),
3487 &crate_path,
3488 );
3489 let default_impl = generate_default_impl(
3490 struct_name,
3491 &backend,
3492 model.as_deref(),
3493 profile.as_deref(),
3494 &crate_path,
3495 );
3496 (backend_constructors, default_impl)
3497 };
3498
3499 let enhanced_expertise = if is_string_output {
3501 quote! { #expertise }
3503 } else {
3504 let type_name = quote!(#output_type).to_string();
3506 quote! {
3507 {
3508 use std::sync::OnceLock;
3509 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3510
3511 EXPERTISE_CACHE.get_or_init(|| {
3512 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3514
3515 if schema.is_empty() {
3516 format!(
3518 concat!(
3519 #expertise,
3520 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3521 "Do not include any text outside the JSON object."
3522 ),
3523 #type_name
3524 )
3525 } else {
3526 format!(
3528 concat!(
3529 #expertise,
3530 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3531 ),
3532 schema
3533 )
3534 }
3535 }).as_str()
3536 }
3537 }
3538 };
3539
3540 let agent_impl = quote! {
3542 #[async_trait::async_trait]
3543 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3544 where
3545 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3546 {
3547 type Output = #output_type;
3548
3549 fn expertise(&self) -> &str {
3550 #enhanced_expertise
3551 }
3552
3553 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3554 let enhanced_payload = intent.prepend_text(self.expertise());
3556
3557 let response = self.inner.execute(enhanced_payload).await?;
3559
3560 let json_str = #crate_path::extract_json(&response)
3562 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))?;
3563
3564 serde_json::from_str(&json_str)
3566 .map_err(|e| #crate_path::agent::AgentError::ParseError(e.to_string()))
3567 }
3568
3569 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3570 self.inner.is_available().await
3571 }
3572 }
3573 };
3574
3575 let expanded = quote! {
3576 #struct_def
3577 #constructors
3578 #backend_constructors
3579 #default_impl
3580 #agent_impl
3581 };
3582
3583 TokenStream::from(expanded)
3584}
3585
3586#[proc_macro_derive(TypeMarker)]
3608pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3609 let input = parse_macro_input!(input as DeriveInput);
3610 let struct_name = &input.ident;
3611 let type_name_str = struct_name.to_string();
3612
3613 let found_crate =
3615 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3616 let crate_path = match found_crate {
3617 FoundCrate::Itself => {
3618 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3619 quote!(::#ident)
3620 }
3621 FoundCrate::Name(name) => {
3622 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3623 quote!(::#ident)
3624 }
3625 };
3626
3627 let expanded = quote! {
3628 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3629 const TYPE_NAME: &'static str = #type_name_str;
3630 }
3631 };
3632
3633 TokenStream::from(expanded)
3634}
3635
3636#[proc_macro_attribute]
3672pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
3673 let input = parse_macro_input!(item as syn::DeriveInput);
3674 let struct_name = &input.ident;
3675 let vis = &input.vis;
3676 let type_name_str = struct_name.to_string();
3677
3678 let default_fn_name = syn::Ident::new(
3680 &format!("default_{}_type", to_snake_case(&type_name_str)),
3681 struct_name.span(),
3682 );
3683
3684 let found_crate =
3686 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3687 let crate_path = match found_crate {
3688 FoundCrate::Itself => {
3689 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3690 quote!(::#ident)
3691 }
3692 FoundCrate::Name(name) => {
3693 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3694 quote!(::#ident)
3695 }
3696 };
3697
3698 let fields = match &input.data {
3700 syn::Data::Struct(data_struct) => match &data_struct.fields {
3701 syn::Fields::Named(fields) => &fields.named,
3702 _ => {
3703 return syn::Error::new_spanned(
3704 struct_name,
3705 "type_marker only works with structs with named fields",
3706 )
3707 .to_compile_error()
3708 .into();
3709 }
3710 },
3711 _ => {
3712 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
3713 .to_compile_error()
3714 .into();
3715 }
3716 };
3717
3718 let mut new_fields = vec![];
3720
3721 let default_fn_name_str = default_fn_name.to_string();
3723 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
3724
3725 new_fields.push(quote! {
3730 #[serde(default = #default_fn_name_lit)]
3731 __type: String
3732 });
3733
3734 for field in fields {
3736 new_fields.push(quote! { #field });
3737 }
3738
3739 let attrs = &input.attrs;
3741 let generics = &input.generics;
3742
3743 let expanded = quote! {
3744 fn #default_fn_name() -> String {
3746 #type_name_str.to_string()
3747 }
3748
3749 #(#attrs)*
3751 #vis struct #struct_name #generics {
3752 #(#new_fields),*
3753 }
3754
3755 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3757 const TYPE_NAME: &'static str = #type_name_str;
3758 }
3759 };
3760
3761 TokenStream::from(expanded)
3762}