1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144
145 for field in fields.iter() {
147 let field_name = field.ident.as_ref().unwrap();
148 let field_name_str = field_name.to_string();
149 let attrs = parse_field_prompt_attrs(&field.attrs);
150
151 if field_name_str == "__type" {
155 continue;
156 }
157
158 if attrs.skip {
160 continue;
161 }
162
163 let field_docs = extract_doc_comments(&field.attrs);
165
166 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
168
169 if is_vec {
170 let comment = if !field_docs.is_empty() {
173 format!(" // {}", field_docs)
174 } else {
175 String::new()
176 };
177
178 field_schema_parts.push(quote! {
179 {
180 let type_name = stringify!(#inner_type);
181 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
182 }
183 });
184 } else {
185 let field_type = &field.ty;
187 let is_primitive = is_primitive_type(field_type);
188
189 if !is_primitive {
190 let comment = if !field_docs.is_empty() {
193 format!(" // {}", field_docs)
194 } else {
195 String::new()
196 };
197
198 field_schema_parts.push(quote! {
199 {
200 let type_name = stringify!(#field_type);
201 format!(" {}: {};{}", #field_name_str, type_name, #comment)
202 }
203 });
204 } else {
205 let type_str = format_type_for_schema(&field.ty);
208 let comment = if !field_docs.is_empty() {
209 format!(" // {}", field_docs)
210 } else {
211 String::new()
212 };
213
214 field_schema_parts.push(quote! {
215 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
216 });
217 }
218 }
219 }
220
221 let mut header_lines = Vec::new();
232
233 if !struct_docs.is_empty() {
235 header_lines.push("/**".to_string());
236 header_lines.push(format!(" * {}", struct_docs));
237 header_lines.push(" */".to_string());
238 }
239
240 header_lines.push(format!("type {} = {{", struct_name));
242
243 quote! {
244 {
245 let mut lines = vec![];
246 #(lines.push(#header_lines.to_string());)*
247 #(lines.push(#field_schema_parts);)*
248 lines.push("}".to_string());
249 vec![#crate_path::prompt::PromptPart::Text(lines.join("\n"))]
250 }
251 }
252}
253
254fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
256 if let syn::Type::Path(type_path) = ty
257 && let Some(last_segment) = type_path.path.segments.last()
258 && last_segment.ident == "Vec"
259 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
260 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
261 {
262 return (true, Some(inner_type));
263 }
264 (false, None)
265}
266
267fn is_primitive_type(ty: &syn::Type) -> bool {
269 if let syn::Type::Path(type_path) = ty
270 && let Some(last_segment) = type_path.path.segments.last()
271 {
272 let type_name = last_segment.ident.to_string();
273 matches!(
274 type_name.as_str(),
275 "String"
276 | "str"
277 | "i8"
278 | "i16"
279 | "i32"
280 | "i64"
281 | "i128"
282 | "isize"
283 | "u8"
284 | "u16"
285 | "u32"
286 | "u64"
287 | "u128"
288 | "usize"
289 | "f32"
290 | "f64"
291 | "bool"
292 | "Vec"
293 | "Option"
294 | "HashMap"
295 | "BTreeMap"
296 | "HashSet"
297 | "BTreeSet"
298 )
299 } else {
300 true
302 }
303}
304
305fn format_type_for_schema(ty: &syn::Type) -> String {
307 match ty {
309 syn::Type::Path(type_path) => {
310 let path = &type_path.path;
311 if let Some(last_segment) = path.segments.last() {
312 let type_name = last_segment.ident.to_string();
313
314 if type_name == "Option"
316 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
317 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
318 {
319 return format!("{} | null", format_type_for_schema(inner_type));
320 }
321
322 match type_name.as_str() {
324 "String" | "str" => "string".to_string(),
325 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
326 | "u64" | "u128" | "usize" => "number".to_string(),
327 "f32" | "f64" => "number".to_string(),
328 "bool" => "boolean".to_string(),
329 "Vec" => {
330 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
331 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
332 {
333 return format!("{}[]", format_type_for_schema(inner_type));
334 }
335 "array".to_string()
336 }
337 _ => type_name.to_lowercase(),
338 }
339 } else {
340 "unknown".to_string()
341 }
342 }
343 _ => "unknown".to_string(),
344 }
345}
346
/// Result of inspecting an item's `#[prompt(...)]` attributes, as returned by
/// `parse_prompt_attribute`.
enum PromptAttribute {
    /// A `#[prompt(skip)]` attribute was found.
    Skip,
    /// A `#[prompt("...")]` attribute was found; holds the string literal's value.
    Description(String),
    /// No `#[prompt(...)]` attribute matched either of the forms above.
    None,
}
353
354fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
356 for attr in attrs {
357 if attr.path().is_ident("prompt") {
358 if let Ok(meta_list) = attr.meta.require_list() {
360 let tokens = &meta_list.tokens;
361 let tokens_str = tokens.to_string();
362 if tokens_str == "skip" {
363 return PromptAttribute::Skip;
364 }
365 }
366
367 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
369 return PromptAttribute::Description(lit_str.value());
370 }
371 }
372 }
373 PromptAttribute::None
374}
375
/// Field-level options collected from `#[prompt(...)]` attributes by
/// `parse_field_prompt_attrs`. All options default to "not set".
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by `#[prompt(skip)]`.
    skip: bool,
    /// Value of `#[prompt(rename = "...")]`, if given.
    rename: Option<String>,
    /// Value of `#[prompt(format_with = "...")]`, if given.
    format_with: Option<String>,
    /// Set by `#[prompt(image)]`.
    image: bool,
    /// Value of `#[prompt(example = "...")]`, if given.
    example: Option<String>,
}
385
386fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
388 let mut result = FieldPromptAttrs::default();
389
390 for attr in attrs {
391 if attr.path().is_ident("prompt") {
392 if let Ok(meta_list) = attr.meta.require_list() {
394 if let Ok(metas) =
396 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
397 {
398 for meta in metas {
399 match meta {
400 Meta::Path(path) if path.is_ident("skip") => {
401 result.skip = true;
402 }
403 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
404 if let syn::Expr::Lit(syn::ExprLit {
405 lit: syn::Lit::Str(lit_str),
406 ..
407 }) = nv.value
408 {
409 result.rename = Some(lit_str.value());
410 }
411 }
412 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
413 if let syn::Expr::Lit(syn::ExprLit {
414 lit: syn::Lit::Str(lit_str),
415 ..
416 }) = nv.value
417 {
418 result.format_with = Some(lit_str.value());
419 }
420 }
421 Meta::Path(path) if path.is_ident("image") => {
422 result.image = true;
423 }
424 Meta::NameValue(nv) if nv.path.is_ident("example") => {
425 if let syn::Expr::Lit(syn::ExprLit {
426 lit: syn::Lit::Str(lit_str),
427 ..
428 }) = nv.value
429 {
430 result.example = Some(lit_str.value());
431 }
432 }
433 _ => {}
434 }
435 }
436 } else if meta_list.tokens.to_string() == "skip" {
437 result.skip = true;
439 } else if meta_list.tokens.to_string() == "image" {
440 result.image = true;
442 }
443 }
444 }
445 }
446
447 result
448}
449
/// Derive macro implementing the `ToPrompt` trait of the `llm-toolkit` crate.
///
/// # Enums
///
/// `prompt_schema()` returns a TypeScript-like union (`type Name = | "A" | "B";`)
/// followed by an `Example value: "..."` line naming the first non-skipped
/// variant. A variant's comment comes from `#[prompt("...")]` if present,
/// otherwise from its doc comment. Variants marked `#[prompt(skip)]` are left
/// out of the schema. `to_prompt()` returns `Name: description` when a
/// description or doc comment exists, and the bare variant name otherwise.
///
/// # Structs
///
/// Struct-level `#[prompt(...)]` keys read here: `template`, `template_file`,
/// `mode`, `validate` and `type_marker`. One of three strategies is chosen:
///
/// 1. **Mode-based** — when `mode` is given, or when there is no template and
///    the struct's doc comment contains the substring `mode`. Generates
///    `schema_only` / `example_only` / `full` output.
/// 2. **Template** — when `template` or `template_file` is given. Renders the
///    template, with optional `{{ field:mode }}` placeholders.
/// 3. **Default** — otherwise. Emits one `key: value` line per field.
///
/// # Errors
///
/// Returns a compile error when both `template` and `template_file` are given,
/// or when the template file cannot be found or read.
///
/// # Panics
///
/// Panics (which aborts the derive) when `llm-toolkit` cannot be located via
/// `crate_name`, when the struct does not have named fields in a branch that
/// requires them, when a `format_with` value is not a valid path, and for
/// unions.
#[proc_macro_derive(ToPrompt, attributes(prompt))]
pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    // Work out the path under which the `llm-toolkit` crate is reachable from
    // the crate being compiled (it may have been renamed in `Cargo.toml`).
    let found_crate =
        crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
    let crate_path = match found_crate {
        FoundCrate::Itself => {
            let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
            quote!(::#ident)
        }
        FoundCrate::Name(name) => {
            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
            quote!(::#ident)
        }
    };

    match &input.data {
        Data::Enum(data_enum) => {
            let enum_name = &input.ident;
            let enum_docs = extract_doc_comments(&input.attrs);

            // Schema lines, one per non-skipped variant, built at macro
            // expansion time.
            let mut variant_lines = Vec::new();
            // Name of the first non-skipped variant, used for the example line.
            let mut first_variant_name = None;

            for variant in &data_enum.variants {
                let variant_name = &variant.ident;
                let variant_name_str = variant_name.to_string();

                match parse_prompt_attribute(&variant.attrs) {
                    PromptAttribute::Skip => continue,
                    PromptAttribute::Description(desc) => {
                        variant_lines.push(format!(" | \"{}\" // {}", variant_name_str, desc));
                        if first_variant_name.is_none() {
                            first_variant_name = Some(variant_name_str);
                        }
                    }
                    PromptAttribute::None => {
                        // No explicit description: fall back to the doc comment.
                        let docs = extract_doc_comments(&variant.attrs);
                        if !docs.is_empty() {
                            variant_lines
                                .push(format!(" | \"{}\" // {}", variant_name_str, docs));
                        } else {
                            variant_lines.push(format!(" | \"{}\"", variant_name_str));
                        }
                        if first_variant_name.is_none() {
                            first_variant_name = Some(variant_name_str);
                        }
                    }
                }
            }

            let mut lines = Vec::new();

            // Optional block comment carrying the enum's own documentation.
            if !enum_docs.is_empty() {
                lines.push("/**".to_string());
                lines.push(format!(" * {}", enum_docs));
                lines.push(" */".to_string());
            }

            lines.push(format!("type {} =", enum_name));

            for line in &variant_lines {
                lines.push(line.clone());
            }

            // Terminate the type definition. The `;` is appended to whatever
            // the last line is, so a variant line with a trailing `// ...`
            // comment ends up with the `;` inside that comment.
            if let Some(last) = lines.last_mut()
                && !last.ends_with(';')
            {
                last.push(';');
            }

            if let Some(first_name) = first_variant_name {
                lines.push("".to_string());
                lines.push(format!("Example value: \"{}\"", first_name));
            }

            let prompt_string = lines.join("\n");
            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

            // Arms of the generated `to_prompt` match. Unlike the schema,
            // skipped variants still get an arm so the match stays exhaustive.
            // NOTE(review): every arm matches `Self::Variant` with no field
            // pattern, so this presumably only supports unit variants - confirm.
            let mut match_arms = Vec::new();
            for variant in &data_enum.variants {
                let variant_name = &variant.ident;

                match parse_prompt_attribute(&variant.attrs) {
                    PromptAttribute::Skip => {
                        match_arms.push(quote! {
                            Self::#variant_name => stringify!(#variant_name).to_string()
                        });
                    }
                    PromptAttribute::Description(desc) => {
                        match_arms.push(quote! {
                            Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
                        });
                    }
                    PromptAttribute::None => {
                        let variant_docs = extract_doc_comments(&variant.attrs);
                        if !variant_docs.is_empty() {
                            match_arms.push(quote! {
                                Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
                            });
                        } else {
                            match_arms.push(quote! {
                                Self::#variant_name => stringify!(#variant_name).to_string()
                            });
                        }
                    }
                }
            }

            // An enum without variants gets an empty match on `*self`.
            let to_prompt_impl = if match_arms.is_empty() {
                quote! {
                    fn to_prompt(&self) -> String {
                        match *self {}
                    }
                }
            } else {
                quote! {
                    fn to_prompt(&self) -> String {
                        match self {
                            #(#match_arms),*
                        }
                    }
                }
            };

            let expanded = quote! {
                impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
                    fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
                        vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
                    }

                    #to_prompt_impl

                    fn prompt_schema() -> String {
                        #prompt_string.to_string()
                    }
                }
            };

            TokenStream::from(expanded)
        }
        Data::Struct(data_struct) => {
            // Struct-level `#[prompt(...)]` options.
            let mut template_attr = None;
            let mut template_file_attr = None;
            let mut mode_attr = None;
            let mut validate_attr = false;
            let mut type_marker_attr = false;

            for attr in &input.attrs {
                if attr.path().is_ident("prompt") {
                    // Attributes whose arguments do not parse as a list of
                    // meta items are silently ignored.
                    if let Ok(metas) =
                        attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
                    {
                        for meta in metas {
                            match meta {
                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
                                    if let syn::Expr::Lit(expr_lit) = nv.value
                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
                                    {
                                        template_attr = Some(lit_str.value());
                                    }
                                }
                                Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
                                    if let syn::Expr::Lit(expr_lit) = nv.value
                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
                                    {
                                        template_file_attr = Some(lit_str.value());
                                    }
                                }
                                Meta::NameValue(nv) if nv.path.is_ident("mode") => {
                                    if let syn::Expr::Lit(expr_lit) = nv.value
                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
                                    {
                                        mode_attr = Some(lit_str.value());
                                    }
                                }
                                Meta::NameValue(nv) if nv.path.is_ident("validate") => {
                                    if let syn::Expr::Lit(expr_lit) = nv.value
                                        && let syn::Lit::Bool(lit_bool) = expr_lit.lit
                                    {
                                        validate_attr = lit_bool.value();
                                    }
                                }
                                // `type_marker` is accepted both as
                                // `type_marker = true/false` and as a bare flag.
                                Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
                                    if let syn::Expr::Lit(expr_lit) = nv.value
                                        && let syn::Lit::Bool(lit_bool) = expr_lit.lit
                                    {
                                        type_marker_attr = lit_bool.value();
                                    }
                                }
                                Meta::Path(path) if path.is_ident("type_marker") => {
                                    type_marker_attr = true;
                                }
                                _ => {}
                            }
                        }
                    }
                }
            }

            if template_attr.is_some() && template_file_attr.is_some() {
                return syn::Error::new(
                    input.ident.span(),
                    "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
                ).to_compile_error().into();
            }

            // Resolve the template text: either read it from `template_file`
            // at macro expansion time, or use the inline `template` string.
            let template_str = if let Some(file_path) = template_file_attr {
                let mut full_path = None;

                // Attempt 1: relative to CARGO_MANIFEST_DIR.
                if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
                    let is_trybuild = manifest_dir.contains("target/tests/trybuild");

                    if !is_trybuild {
                        let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
                        if candidate.exists() {
                            full_path = Some(candidate);
                        }
                    } else {
                        // The manifest dir lives under `target/tests/trybuild`:
                        // look in `<root>/crates/llm-toolkit-macros` instead,
                        // where `<root>` is everything before `/target/...`.
                        if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
                            let workspace_root = &manifest_dir[..target_pos];
                            let original_macros_dir = std::path::Path::new(workspace_root)
                                .join("crates")
                                .join("llm-toolkit-macros");

                            let candidate = original_macros_dir.join(&file_path);
                            if candidate.exists() {
                                full_path = Some(candidate);
                            }
                        }
                    }
                }

                // Attempt 2: the path exactly as written (absolute, or
                // relative to the current working directory).
                if full_path.is_none() {
                    let candidate = std::path::Path::new(&file_path).to_path_buf();
                    if candidate.exists() {
                        full_path = Some(candidate);
                    }
                }

                // Attempt 3: walk up from the current directory (at most 10
                // levels), checking `crates/llm-toolkit-macros/<file>` and
                // `<dir>/<file>` at each level.
                if full_path.is_none()
                    && let Ok(current_dir) = std::env::current_dir()
                {
                    let mut search_dir = current_dir.as_path();
                    for _ in 0..10 {
                        let macros_dir = search_dir.join("crates/llm-toolkit-macros");
                        if macros_dir.exists() {
                            let candidate = macros_dir.join(&file_path);
                            if candidate.exists() {
                                full_path = Some(candidate);
                                break;
                            }
                        }
                        let candidate = search_dir.join(&file_path);
                        if candidate.exists() {
                            full_path = Some(candidate);
                            break;
                        }
                        if let Some(parent) = search_dir.parent() {
                            search_dir = parent;
                        } else {
                            break;
                        }
                    }
                }

                if full_path.is_none() {
                    // The message lists only the two primary candidates, not
                    // every directory visited by the upward search.
                    let mut error_msg = format!(
                        "Template file '{}' not found at compile time.\n\nSearched in:",
                        file_path
                    );

                    if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
                        let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
                        error_msg.push_str(&format!("\n - {}", candidate.display()));
                    }

                    if let Ok(current_dir) = std::env::current_dir() {
                        let candidate = current_dir.join(&file_path);
                        error_msg.push_str(&format!("\n - {}", candidate.display()));
                    }

                    error_msg.push_str("\n\nPlease ensure:");
                    error_msg.push_str("\n 1. The template file exists");
                    error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
                    error_msg.push_str("\n 3. There are no typos in the path");

                    return syn::Error::new(input.ident.span(), error_msg)
                        .to_compile_error()
                        .into();
                }

                // Cannot fail: the `is_none` case returned above.
                let final_path = full_path.unwrap();

                match std::fs::read_to_string(&final_path) {
                    Ok(content) => Some(content),
                    Err(e) => {
                        return syn::Error::new(
                            input.ident.span(),
                            format!(
                                "Failed to read template file '{}': {}\n\nPath resolved to: {}",
                                file_path,
                                e,
                                final_path.display()
                            ),
                        )
                        .to_compile_error()
                        .into();
                    }
                }
            } else {
                template_attr
            };

            // Optional validation (`validate = true`). Problems are reported
            // as messages on stderr and never fail the build.
            if validate_attr && let Some(template) = &template_str {
                let mut env = minijinja::Environment::new();
                if let Err(e) = env.add_template("validation", template) {
                    let warning_msg =
                        format!("Template validation warning: Invalid Jinja syntax - {}", e);
                    let warning_ident = syn::Ident::new(
                        "TEMPLATE_VALIDATION_WARNING",
                        proc_macro2::Span::call_site(),
                    );
                    // NOTE(review): these tokens are built but never added to
                    // the macro output, so only the stderr line below is
                    // produced.
                    let _warning_tokens = quote! {
                        #[deprecated(note = #warning_msg)]
                        const #warning_ident: () = ();
                        let _ = #warning_ident;
                    };
                    // NOTE(review): printed to stderr with a `cargo:warning=`
                    // prefix; whether Cargo surfaces this from a proc macro
                    // should be confirmed.
                    eprintln!("cargo:warning={}", warning_msg);
                }

                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
                    &fields.named
                } else {
                    panic!("Template validation is only supported for structs with named fields.");
                };

                let field_names: std::collections::HashSet<String> = fields
                    .iter()
                    .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
                    .collect();

                let placeholders = parse_template_placeholders_with_mode(template);

                // Report placeholders that do not correspond to a field;
                // `self` is always allowed.
                for (placeholder_name, _mode) in &placeholders {
                    if placeholder_name != "self" && !field_names.contains(placeholder_name) {
                        let warning_msg = format!(
                            "Template validation warning: Variable '{}' used in template but not found in struct fields",
                            placeholder_name
                        );
                        eprintln!("cargo:warning={}", warning_msg);
                    }
                }
            }

            let name = input.ident;
            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

            let struct_docs = extract_doc_comments(&input.attrs);

            // Mode-based generation is selected by an explicit `mode`
            // attribute, or - when there is no template - by the struct's doc
            // comment containing the substring "mode".
            let is_mode_based =
                mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));

            // `is_mode_based` already includes `mode_attr.is_some()`, so the
            // second operand below is redundant.
            let expanded = if is_mode_based || mode_attr.is_some() {
                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
                    &fields.named
                } else {
                    panic!(
                        "Mode-based prompt generation is only supported for structs with named fields."
                    );
                };

                let struct_name_str = name.to_string();

                // Text-based detection: true when any `#[derive(...)]` list
                // contains the substring "Default" (this also matches other
                // derive names that merely contain that substring).
                let has_default = input.attrs.iter().any(|attr| {
                    if attr.path().is_ident("derive")
                        && let Ok(meta_list) = attr.meta.require_list()
                    {
                        let tokens_str = meta_list.tokens.to_string();
                        tokens_str.contains("Default")
                    } else {
                        false
                    }
                });

                let schema_parts = generate_schema_only_parts(
                    &struct_name_str,
                    &struct_docs,
                    fields,
                    &crate_path,
                    type_marker_attr,
                );

                let example_parts = generate_example_only_parts(fields, has_default, &crate_path);

                quote! {
                    impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
                        fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
                            match mode {
                                "schema_only" => #schema_parts,
                                "example_only" => #example_parts,
                                // Any unrecognised mode behaves like "full":
                                // schema, a heading, then the example.
                                "full" | _ => {
                                    let mut parts = Vec::new();

                                    let schema_parts = #schema_parts;
                                    parts.extend(schema_parts);

                                    parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
                                    parts.push(#crate_path::prompt::PromptPart::Text(
                                        format!("Here is an example of a valid `{}` object:", #struct_name_str)
                                    ));

                                    let example_parts = #example_parts;
                                    parts.extend(example_parts);

                                    parts
                                }
                            }
                        }

                        fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
                            self.to_prompt_parts_with_mode("full")
                        }

                        fn to_prompt(&self) -> String {
                            // Keep only the text parts and join them with newlines.
                            self.to_prompt_parts()
                                .into_iter()
                                .filter_map(|part| match part {
                                    #crate_path::prompt::PromptPart::Text(text) => Some(text),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join("\n")
                        }

                        fn prompt_schema() -> String {
                            use std::sync::OnceLock;
                            // The schema text is computed once and cloned on
                            // each later call.
                            static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();

                            SCHEMA_CACHE.get_or_init(|| {
                                let schema_parts = #schema_parts;
                                schema_parts
                                    .into_iter()
                                    .filter_map(|part| match part {
                                        #crate_path::prompt::PromptPart::Text(text) => Some(text),
                                        _ => None,
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n")
                            }).clone()
                        }
                    }
                }
            } else if let Some(template) = template_str {
                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
                    &fields.named
                } else {
                    panic!(
                        "Template prompt generation is only supported for structs with named fields."
                    );
                };

                let placeholders = parse_template_placeholders_with_mode(&template);
                // True when at least one `{{ field:mode }}` placeholder names
                // an actual field of this struct.
                let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
                    mode.is_some()
                        && fields
                            .iter()
                            .any(|f| f.ident.as_ref().unwrap() == field_name)
                });

                // Fields marked `#[prompt(image)]` contribute their own prompt
                // parts ahead of the rendered text.
                let mut image_field_parts = Vec::new();
                for f in fields.iter() {
                    let field_name = f.ident.as_ref().unwrap();
                    let attrs = parse_field_prompt_attrs(&f.attrs);

                    if attrs.image {
                        image_field_parts.push(quote! {
                            parts.extend(self.#field_name.to_prompt_parts());
                        });
                    }
                }

                if has_mode_syntax {
                    // Statements that fill the render context in the
                    // generated code.
                    let mut context_fields = Vec::new();
                    let mut modified_template = template.clone();

                    for (field_name, mode_opt) in &placeholders {
                        if let Some(mode) = mode_opt {
                            // Rewrite `{{ field:mode }}` to `{{ field__mode }}`
                            // and supply that key in the context.
                            let unique_key = format!("{}__{}", field_name, mode);

                            // NOTE(review): only the exact spelling
                            // `{{ field:mode }}` is replaced here, while the
                            // placeholder regex also accepts other whitespace
                            // (e.g. `{{field : mode}}`) - confirm whether such
                            // spellings need handling.
                            let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
                            let replacement = format!("{{{{ {} }}}}", unique_key);
                            modified_template = modified_template.replace(&pattern, &replacement);

                            let field_ident =
                                syn::Ident::new(field_name, proc_macro2::Span::call_site());

                            context_fields.push(quote! {
                                context.insert(
                                    #unique_key.to_string(),
                                    minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
                                );
                            });
                        }
                    }

                    // Every field that has no mode placeholder is added to the
                    // context under its own name.
                    for field in fields.iter() {
                        let field_name = field.ident.as_ref().unwrap();
                        let field_name_str = field_name.to_string();

                        let has_mode_entry = placeholders
                            .iter()
                            .any(|(name, mode)| name == &field_name_str && mode.is_some());

                        if !has_mode_entry {
                            // Decided on the last path segment's name only.
                            let is_primitive = match &field.ty {
                                syn::Type::Path(type_path) => {
                                    if let Some(segment) = type_path.path.segments.last() {
                                        let type_name = segment.ident.to_string();
                                        matches!(
                                            type_name.as_str(),
                                            "String"
                                                | "str"
                                                | "i8"
                                                | "i16"
                                                | "i32"
                                                | "i64"
                                                | "i128"
                                                | "isize"
                                                | "u8"
                                                | "u16"
                                                | "u32"
                                                | "u64"
                                                | "u128"
                                                | "usize"
                                                | "f32"
                                                | "f64"
                                                | "bool"
                                                | "char"
                                        )
                                    } else {
                                        false
                                    }
                                }
                                _ => false,
                            };

                            if is_primitive {
                                // Primitive values are passed through serialization.
                                context_fields.push(quote! {
                                    context.insert(
                                        #field_name_str.to_string(),
                                        minijinja::Value::from_serialize(&self.#field_name)
                                    );
                                });
                            } else {
                                // Other values are rendered via their own `to_prompt()`.
                                context_fields.push(quote! {
                                    context.insert(
                                        #field_name_str.to_string(),
                                        minijinja::Value::from(self.#field_name.to_prompt())
                                    );
                                });
                            }
                        }
                    }

                    quote! {
                        impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
                            fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
                                let mut parts = Vec::new();

                                #(#image_field_parts)*

                                let text = {
                                    let mut env = minijinja::Environment::new();
                                    env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
                                        panic!("Failed to parse template: {}", e)
                                    });

                                    let tmpl = env.get_template("prompt").unwrap();

                                    let mut context = std::collections::HashMap::new();
                                    #(#context_fields)*

                                    // A render failure becomes the returned text
                                    // rather than a panic.
                                    tmpl.render(context).unwrap_or_else(|e| {
                                        format!("Failed to render prompt: {}", e)
                                    })
                                };

                                if !text.is_empty() {
                                    parts.push(#crate_path::prompt::PromptPart::Text(text));
                                }

                                parts
                            }

                            fn to_prompt(&self) -> String {
                                let mut env = minijinja::Environment::new();
                                env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
                                    panic!("Failed to parse template: {}", e)
                                });

                                let tmpl = env.get_template("prompt").unwrap();

                                let mut context = std::collections::HashMap::new();
                                #(#context_fields)*

                                tmpl.render(context).unwrap_or_else(|e| {
                                    format!("Failed to render prompt: {}", e)
                                })
                            }

                            fn prompt_schema() -> String {
                                // Template-based structs expose no schema.
                                String::new()
                            }
                        }
                    }
                } else {
                    // No mode placeholders: hand the template and `self` to
                    // the crate's `render_prompt` helper.
                    quote! {
                        impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
                            fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
                                let mut parts = Vec::new();

                                #(#image_field_parts)*

                                let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
                                    format!("Failed to render prompt: {}", e)
                                });
                                if !text.is_empty() {
                                    parts.push(#crate_path::prompt::PromptPart::Text(text));
                                }

                                parts
                            }

                            fn to_prompt(&self) -> String {
                                #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
                                    format!("Failed to render prompt: {}", e)
                                })
                            }

                            fn prompt_schema() -> String {
                                // Template-based structs expose no schema.
                                String::new()
                            }
                        }
                    }
                }
            } else {
                // Default strategy: one `key: value` line per field.
                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
                    &fields.named
                } else {
                    panic!(
                        "Default prompt generation is only supported for structs with named fields."
                    );
                };

                let mut text_field_parts = Vec::new();
                let mut image_field_parts = Vec::new();

                for f in fields.iter() {
                    let field_name = f.ident.as_ref().unwrap();
                    let attrs = parse_field_prompt_attrs(&f.attrs);

                    if attrs.skip {
                        continue;
                    }

                    if attrs.image {
                        image_field_parts.push(quote! {
                            parts.extend(self.#field_name.to_prompt_parts());
                        });
                    } else {
                        // Key precedence: `rename`, then the field's doc
                        // comment, then the field name.
                        let key = if let Some(rename) = attrs.rename {
                            rename
                        } else {
                            let doc_comment = extract_doc_comments(&f.attrs);
                            if !doc_comment.is_empty() {
                                doc_comment
                            } else {
                                field_name.to_string()
                            }
                        };

                        // Value: the `format_with` function applied to a
                        // reference to the field, or the field's `to_prompt()`.
                        let value_expr = if let Some(format_with) = attrs.format_with {
                            let func_path: syn::Path =
                                syn::parse_str(&format_with).unwrap_or_else(|_| {
                                    panic!("Invalid function path: {}", format_with)
                                });
                            quote! { #func_path(&self.#field_name) }
                        } else {
                            quote! { self.#field_name.to_prompt() }
                        };

                        text_field_parts.push(quote! {
                            text_parts.push(format!("{}: {}", #key, #value_expr));
                        });
                    }
                }

                quote! {
                    impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
                        fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
                            let mut parts = Vec::new();

                            #(#image_field_parts)*

                            let mut text_parts = Vec::new();
                            #(#text_field_parts)*

                            if !text_parts.is_empty() {
                                parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
                            }

                            parts
                        }

                        fn to_prompt(&self) -> String {
                            let mut text_parts = Vec::new();
                            #(#text_field_parts)*
                            text_parts.join("\n")
                        }

                        fn prompt_schema() -> String {
                            // The default strategy exposes no schema.
                            String::new()
                        }
                    }
                }
            };

            TokenStream::from(expanded)
        }
        Data::Union(_) => {
            panic!("`#[derive(ToPrompt)]` is not supported for unions");
        }
    }
}
1344
/// One prompt target declared by a struct-level `#[prompt_for(...)]`
/// attribute, as collected by `parse_struct_prompt_for_attrs`.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Value of the attribute's `name = "..."` entry.
    name: String,
    /// Value of the attribute's `template = "..."` entry, if given.
    template: Option<String>,
    /// Per-field settings keyed by `String`; created empty by
    /// `parse_struct_prompt_for_attrs`.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}
1352
/// Field-level settings from one `#[prompt_for(...)]` attribute, as collected
/// by `parse_prompt_for_attrs`. All settings default to "not set".
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Set by a `skip` entry (or by the bare `#[prompt_for(skip)]` form).
    skip: bool,
    /// Value of `rename = "..."`, if given.
    rename: Option<String>,
    /// Value of `format_with = "..."`, if given.
    format_with: Option<String>,
    /// Set by an `image` entry.
    image: bool,
    /// Set to `true` by `parse_prompt_for_attrs` whenever the attribute names
    /// a target via `name = "..."`.
    include_only: bool,
}
1362
1363fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1365 let mut configs = Vec::new();
1366
1367 for attr in attrs {
1368 if attr.path().is_ident("prompt_for")
1369 && let Ok(meta_list) = attr.meta.require_list()
1370 {
1371 if meta_list.tokens.to_string() == "skip" {
1373 let config = FieldTargetConfig {
1375 skip: true,
1376 ..Default::default()
1377 };
1378 configs.push(("*".to_string(), config));
1379 } else if let Ok(metas) =
1380 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1381 {
1382 let mut target_name = None;
1383 let mut config = FieldTargetConfig::default();
1384
1385 for meta in metas {
1386 match meta {
1387 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1388 if let syn::Expr::Lit(syn::ExprLit {
1389 lit: syn::Lit::Str(lit_str),
1390 ..
1391 }) = nv.value
1392 {
1393 target_name = Some(lit_str.value());
1394 }
1395 }
1396 Meta::Path(path) if path.is_ident("skip") => {
1397 config.skip = true;
1398 }
1399 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1400 if let syn::Expr::Lit(syn::ExprLit {
1401 lit: syn::Lit::Str(lit_str),
1402 ..
1403 }) = nv.value
1404 {
1405 config.rename = Some(lit_str.value());
1406 }
1407 }
1408 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1409 if let syn::Expr::Lit(syn::ExprLit {
1410 lit: syn::Lit::Str(lit_str),
1411 ..
1412 }) = nv.value
1413 {
1414 config.format_with = Some(lit_str.value());
1415 }
1416 }
1417 Meta::Path(path) if path.is_ident("image") => {
1418 config.image = true;
1419 }
1420 _ => {}
1421 }
1422 }
1423
1424 if let Some(name) = target_name {
1425 config.include_only = true;
1426 configs.push((name, config));
1427 }
1428 }
1429 }
1430 }
1431
1432 configs
1433}
1434
1435fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1437 let mut targets = Vec::new();
1438
1439 for attr in attrs {
1440 if attr.path().is_ident("prompt_for")
1441 && let Ok(meta_list) = attr.meta.require_list()
1442 && let Ok(metas) =
1443 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1444 {
1445 let mut target_name = None;
1446 let mut template = None;
1447
1448 for meta in metas {
1449 match meta {
1450 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1451 if let syn::Expr::Lit(syn::ExprLit {
1452 lit: syn::Lit::Str(lit_str),
1453 ..
1454 }) = nv.value
1455 {
1456 target_name = Some(lit_str.value());
1457 }
1458 }
1459 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1460 if let syn::Expr::Lit(syn::ExprLit {
1461 lit: syn::Lit::Str(lit_str),
1462 ..
1463 }) = nv.value
1464 {
1465 template = Some(lit_str.value());
1466 }
1467 }
1468 _ => {}
1469 }
1470 }
1471
1472 if let Some(name) = target_name {
1473 targets.push(TargetInfo {
1474 name,
1475 template,
1476 field_configs: std::collections::HashMap::new(),
1477 });
1478 }
1479 }
1480 }
1481
1482 targets
1483}
1484
1485#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1486pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1487 let input = parse_macro_input!(input as DeriveInput);
1488
1489 let found_crate =
1490 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1491 let crate_path = match found_crate {
1492 FoundCrate::Itself => {
1493 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1495 quote!(::#ident)
1496 }
1497 FoundCrate::Name(name) => {
1498 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1499 quote!(::#ident)
1500 }
1501 };
1502
1503 let data_struct = match &input.data {
1505 Data::Struct(data) => data,
1506 _ => {
1507 return syn::Error::new(
1508 input.ident.span(),
1509 "`#[derive(ToPromptSet)]` is only supported for structs",
1510 )
1511 .to_compile_error()
1512 .into();
1513 }
1514 };
1515
1516 let fields = match &data_struct.fields {
1517 syn::Fields::Named(fields) => &fields.named,
1518 _ => {
1519 return syn::Error::new(
1520 input.ident.span(),
1521 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1522 )
1523 .to_compile_error()
1524 .into();
1525 }
1526 };
1527
1528 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1530
1531 for field in fields.iter() {
1533 let field_name = field.ident.as_ref().unwrap().to_string();
1534 let field_configs = parse_prompt_for_attrs(&field.attrs);
1535
1536 for (target_name, config) in field_configs {
1537 if target_name == "*" {
1538 for target in &mut targets {
1540 target
1541 .field_configs
1542 .entry(field_name.clone())
1543 .or_insert_with(FieldTargetConfig::default)
1544 .skip = config.skip;
1545 }
1546 } else {
1547 let target_exists = targets.iter().any(|t| t.name == target_name);
1549 if !target_exists {
1550 targets.push(TargetInfo {
1552 name: target_name.clone(),
1553 template: None,
1554 field_configs: std::collections::HashMap::new(),
1555 });
1556 }
1557
1558 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1559
1560 target.field_configs.insert(field_name.clone(), config);
1561 }
1562 }
1563 }
1564
1565 let mut match_arms = Vec::new();
1567
1568 for target in &targets {
1569 let target_name = &target.name;
1570
1571 if let Some(template_str) = &target.template {
1572 let mut image_parts = Vec::new();
1574
1575 for field in fields.iter() {
1576 let field_name = field.ident.as_ref().unwrap();
1577 let field_name_str = field_name.to_string();
1578
1579 if let Some(config) = target.field_configs.get(&field_name_str)
1580 && config.image
1581 {
1582 image_parts.push(quote! {
1583 parts.extend(self.#field_name.to_prompt_parts());
1584 });
1585 }
1586 }
1587
1588 match_arms.push(quote! {
1589 #target_name => {
1590 let mut parts = Vec::new();
1591
1592 #(#image_parts)*
1593
1594 let text = #crate_path::prompt::render_prompt(#template_str, self)
1595 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1596 target: #target_name.to_string(),
1597 source: e,
1598 })?;
1599
1600 if !text.is_empty() {
1601 parts.push(#crate_path::prompt::PromptPart::Text(text));
1602 }
1603
1604 Ok(parts)
1605 }
1606 });
1607 } else {
1608 let mut text_field_parts = Vec::new();
1610 let mut image_field_parts = Vec::new();
1611
1612 for field in fields.iter() {
1613 let field_name = field.ident.as_ref().unwrap();
1614 let field_name_str = field_name.to_string();
1615
1616 let config = target.field_configs.get(&field_name_str);
1618
1619 if let Some(cfg) = config
1621 && cfg.skip
1622 {
1623 continue;
1624 }
1625
1626 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1630 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1631 .iter()
1632 .any(|(name, _)| name != "*");
1633
1634 if has_any_target_specific_config && !is_explicitly_for_this_target {
1635 continue;
1636 }
1637
1638 if let Some(cfg) = config {
1639 if cfg.image {
1640 image_field_parts.push(quote! {
1641 parts.extend(self.#field_name.to_prompt_parts());
1642 });
1643 } else {
1644 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1645
1646 let value_expr = if let Some(format_with) = &cfg.format_with {
1647 match syn::parse_str::<syn::Path>(format_with) {
1649 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1650 Err(_) => {
1651 let error_msg = format!(
1653 "Invalid function path in format_with: '{}'",
1654 format_with
1655 );
1656 quote! {
1657 compile_error!(#error_msg);
1658 String::new()
1659 }
1660 }
1661 }
1662 } else {
1663 quote! { self.#field_name.to_prompt() }
1664 };
1665
1666 text_field_parts.push(quote! {
1667 text_parts.push(format!("{}: {}", #key, #value_expr));
1668 });
1669 }
1670 } else {
1671 text_field_parts.push(quote! {
1673 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1674 });
1675 }
1676 }
1677
1678 match_arms.push(quote! {
1679 #target_name => {
1680 let mut parts = Vec::new();
1681
1682 #(#image_field_parts)*
1683
1684 let mut text_parts = Vec::new();
1685 #(#text_field_parts)*
1686
1687 if !text_parts.is_empty() {
1688 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1689 }
1690
1691 Ok(parts)
1692 }
1693 });
1694 }
1695 }
1696
1697 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1699
1700 match_arms.push(quote! {
1702 _ => {
1703 let available = vec![#(#target_names.to_string()),*];
1704 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1705 target: target.to_string(),
1706 available,
1707 })
1708 }
1709 });
1710
1711 let struct_name = &input.ident;
1712 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1713
1714 let expanded = quote! {
1715 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1716 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1717 match target {
1718 #(#match_arms)*
1719 }
1720 }
1721 }
1722 };
1723
1724 TokenStream::from(expanded)
1725}
1726
/// Comma-separated list of types, used as the parsed input of `examples_section!`.
struct TypeList {
    /// The listed types, in source order.
    types: Punctuated<syn::Type, Token![,]>,
}
1731
1732impl Parse for TypeList {
1733 fn parse(input: ParseStream) -> syn::Result<Self> {
1734 Ok(TypeList {
1735 types: Punctuated::parse_terminated(input)?,
1736 })
1737 }
1738}
1739
1740#[proc_macro]
1764pub fn examples_section(input: TokenStream) -> TokenStream {
1765 let input = parse_macro_input!(input as TypeList);
1766
1767 let found_crate =
1768 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1769 let _crate_path = match found_crate {
1770 FoundCrate::Itself => quote!(crate),
1771 FoundCrate::Name(name) => {
1772 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1773 quote!(::#ident)
1774 }
1775 };
1776
1777 let mut type_sections = Vec::new();
1779
1780 for ty in input.types.iter() {
1781 let type_name_str = quote!(#ty).to_string();
1783
1784 type_sections.push(quote! {
1786 {
1787 let type_name = #type_name_str;
1788 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1789 format!("---\n#### `{}`\n{}", type_name, json_example)
1790 }
1791 });
1792 }
1793
1794 let expanded = quote! {
1796 {
1797 let mut sections = Vec::new();
1798 sections.push("---".to_string());
1799 sections.push("### Examples".to_string());
1800 sections.push("".to_string());
1801 sections.push("Here are examples of the data structures you should use.".to_string());
1802 sections.push("".to_string());
1803
1804 #(sections.push(#type_sections);)*
1805
1806 sections.push("---".to_string());
1807
1808 sections.join("\n")
1809 }
1810 };
1811
1812 TokenStream::from(expanded)
1813}
1814
1815fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1817 for attr in attrs {
1818 if attr.path().is_ident("prompt_for")
1819 && let Ok(meta_list) = attr.meta.require_list()
1820 && let Ok(metas) =
1821 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1822 {
1823 let mut target_type = None;
1824 let mut template = None;
1825
1826 for meta in metas {
1827 match meta {
1828 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1829 if let syn::Expr::Lit(syn::ExprLit {
1830 lit: syn::Lit::Str(lit_str),
1831 ..
1832 }) = nv.value
1833 {
1834 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1836 }
1837 }
1838 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1839 if let syn::Expr::Lit(syn::ExprLit {
1840 lit: syn::Lit::Str(lit_str),
1841 ..
1842 }) = nv.value
1843 {
1844 template = Some(lit_str.value());
1845 }
1846 }
1847 _ => {}
1848 }
1849 }
1850
1851 if let (Some(target), Some(tmpl)) = (target_type, template) {
1852 return (target, tmpl);
1853 }
1854 }
1855 }
1856
1857 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1858}
1859
1860#[proc_macro_attribute]
1894pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1895 let input = parse_macro_input!(item as DeriveInput);
1896
1897 let found_crate =
1898 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1899 let crate_path = match found_crate {
1900 FoundCrate::Itself => {
1901 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1903 quote!(::#ident)
1904 }
1905 FoundCrate::Name(name) => {
1906 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1907 quote!(::#ident)
1908 }
1909 };
1910
1911 let enum_data = match &input.data {
1913 Data::Enum(data) => data,
1914 _ => {
1915 return syn::Error::new(
1916 input.ident.span(),
1917 "`#[define_intent]` can only be applied to enums",
1918 )
1919 .to_compile_error()
1920 .into();
1921 }
1922 };
1923
1924 let mut prompt_template = None;
1926 let mut extractor_tag = None;
1927 let mut mode = None;
1928
1929 for attr in &input.attrs {
1930 if attr.path().is_ident("intent")
1931 && let Ok(metas) =
1932 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1933 {
1934 for meta in metas {
1935 match meta {
1936 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1937 if let syn::Expr::Lit(syn::ExprLit {
1938 lit: syn::Lit::Str(lit_str),
1939 ..
1940 }) = nv.value
1941 {
1942 prompt_template = Some(lit_str.value());
1943 }
1944 }
1945 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1946 if let syn::Expr::Lit(syn::ExprLit {
1947 lit: syn::Lit::Str(lit_str),
1948 ..
1949 }) = nv.value
1950 {
1951 extractor_tag = Some(lit_str.value());
1952 }
1953 }
1954 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1955 if let syn::Expr::Lit(syn::ExprLit {
1956 lit: syn::Lit::Str(lit_str),
1957 ..
1958 }) = nv.value
1959 {
1960 mode = Some(lit_str.value());
1961 }
1962 }
1963 _ => {}
1964 }
1965 }
1966 }
1967 }
1968
1969 let mode = mode.unwrap_or_else(|| "single".to_string());
1971
1972 if mode != "single" && mode != "multi_tag" {
1974 return syn::Error::new(
1975 input.ident.span(),
1976 "`mode` must be either \"single\" or \"multi_tag\"",
1977 )
1978 .to_compile_error()
1979 .into();
1980 }
1981
1982 let prompt_template = match prompt_template {
1984 Some(p) => p,
1985 None => {
1986 return syn::Error::new(
1987 input.ident.span(),
1988 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1989 )
1990 .to_compile_error()
1991 .into();
1992 }
1993 };
1994
1995 if mode == "multi_tag" {
1997 let enum_name = &input.ident;
1998 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1999 return generate_multi_tag_output(
2000 &input,
2001 enum_name,
2002 enum_data,
2003 prompt_template,
2004 actions_doc,
2005 );
2006 }
2007
2008 let extractor_tag = match extractor_tag {
2010 Some(t) => t,
2011 None => {
2012 return syn::Error::new(
2013 input.ident.span(),
2014 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2015 )
2016 .to_compile_error()
2017 .into();
2018 }
2019 };
2020
2021 let enum_name = &input.ident;
2023 let enum_docs = extract_doc_comments(&input.attrs);
2024
2025 let mut intents_doc_lines = Vec::new();
2026
2027 if !enum_docs.is_empty() {
2029 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2030 } else {
2031 intents_doc_lines.push(format!("{}:", enum_name));
2032 }
2033 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2035
2036 for variant in &enum_data.variants {
2038 let variant_name = &variant.ident;
2039 let variant_docs = extract_doc_comments(&variant.attrs);
2040
2041 if !variant_docs.is_empty() {
2042 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2043 } else {
2044 intents_doc_lines.push(format!("- {}", variant_name));
2045 }
2046 }
2047
2048 let intents_doc_str = intents_doc_lines.join("\n");
2049
2050 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2052 let user_variables: Vec<String> = placeholders
2053 .iter()
2054 .filter_map(|(name, _)| {
2055 if name != "intents_doc" {
2056 Some(name.clone())
2057 } else {
2058 None
2059 }
2060 })
2061 .collect();
2062
2063 let enum_name_str = enum_name.to_string();
2065 let snake_case_name = to_snake_case(&enum_name_str);
2066 let function_name = syn::Ident::new(
2067 &format!("build_{}_prompt", snake_case_name),
2068 proc_macro2::Span::call_site(),
2069 );
2070
2071 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2073 .iter()
2074 .map(|var| {
2075 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2076 quote! { #ident: &str }
2077 })
2078 .collect();
2079
2080 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2082 .iter()
2083 .map(|var| {
2084 let var_str = var.clone();
2085 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2086 quote! {
2087 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2088 }
2089 })
2090 .collect();
2091
2092 let converted_template = prompt_template.clone();
2094
2095 let extractor_name = syn::Ident::new(
2097 &format!("{}Extractor", enum_name),
2098 proc_macro2::Span::call_site(),
2099 );
2100
2101 let filtered_attrs: Vec<_> = input
2103 .attrs
2104 .iter()
2105 .filter(|attr| !attr.path().is_ident("intent"))
2106 .collect();
2107
2108 let vis = &input.vis;
2110 let generics = &input.generics;
2111 let variants = &enum_data.variants;
2112 let enum_output = quote! {
2113 #(#filtered_attrs)*
2114 #vis enum #enum_name #generics {
2115 #variants
2116 }
2117 };
2118
2119 let expanded = quote! {
2121 #enum_output
2123
2124 pub fn #function_name(#(#function_params),*) -> String {
2126 let mut env = minijinja::Environment::new();
2127 env.add_template("prompt", #converted_template)
2128 .expect("Failed to parse intent prompt template");
2129
2130 let tmpl = env.get_template("prompt").unwrap();
2131
2132 let mut __template_context = std::collections::HashMap::new();
2133
2134 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2136
2137 #(#context_insertions)*
2139
2140 tmpl.render(&__template_context)
2141 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2142 }
2143
2144 pub struct #extractor_name;
2146
2147 impl #extractor_name {
2148 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2149 }
2150
2151 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2152 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2153 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2155 }
2156 }
2157 };
2158
2159 TokenStream::from(expanded)
2160}
2161
/// Lower-cases `s` and inserts `_` before each uppercase character that starts a new
/// uppercase run (never before the first character).
///
/// A run of consecutive uppercase characters is kept together, so `"MyEnum"` becomes
/// `"my_enum"` while `"HTTPServer"` becomes `"httpserver"`.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::new();
    let mut in_upper_run = false;

    for (index, current) in s.chars().enumerate() {
        let is_upper = current.is_uppercase();

        if is_upper && index > 0 && !in_upper_run {
            snake.push('_');
        }

        // Only the first character of the lowercase mapping is used for uppercase input.
        match current.to_lowercase().next() {
            Some(lowered) if is_upper => snake.push(lowered),
            _ => snake.push(current),
        }

        in_upper_run = is_upper;
    }

    snake
}
2182
/// Settings read from a variant-level `#[action(...)]` attribute.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// Value of `tag = "..."`; `None` when the variant declares no tag.
    tag: Option<String>,
}
2188
2189fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2190 let mut result = ActionAttrs::default();
2191
2192 for attr in attrs {
2193 if attr.path().is_ident("action")
2194 && let Ok(meta_list) = attr.meta.require_list()
2195 && let Ok(metas) =
2196 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2197 {
2198 for meta in metas {
2199 if let Meta::NameValue(nv) = meta
2200 && nv.path.is_ident("tag")
2201 && let syn::Expr::Lit(syn::ExprLit {
2202 lit: syn::Lit::Str(lit_str),
2203 ..
2204 }) = nv.value
2205 {
2206 result.tag = Some(lit_str.value());
2207 }
2208 }
2209 }
2210 }
2211
2212 result
2213}
2214
/// Flags read from a field-level `#[action(...)]` attribute.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// Set by `#[action(attribute)]`.
    is_attribute: bool,
    /// Set by `#[action(inner_text)]`.
    is_inner_text: bool,
}
2221
2222fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2223 let mut result = FieldActionAttrs::default();
2224
2225 for attr in attrs {
2226 if attr.path().is_ident("action")
2227 && let Ok(meta_list) = attr.meta.require_list()
2228 {
2229 let tokens_str = meta_list.tokens.to_string();
2230 if tokens_str == "attribute" {
2231 result.is_attribute = true;
2232 } else if tokens_str == "inner_text" {
2233 result.is_inner_text = true;
2234 }
2235 }
2236 }
2237
2238 result
2239}
2240
2241fn generate_multi_tag_actions_doc(
2243 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2244) -> String {
2245 let mut doc_lines = Vec::new();
2246
2247 for variant in variants {
2248 let action_attrs = parse_action_attrs(&variant.attrs);
2249
2250 if let Some(tag) = action_attrs.tag {
2251 let variant_docs = extract_doc_comments(&variant.attrs);
2252
2253 match &variant.fields {
2254 syn::Fields::Unit => {
2255 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2257 }
2258 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2259 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2261 }
2262 syn::Fields::Named(fields) => {
2263 let mut attrs_str = Vec::new();
2265 let mut has_inner_text = false;
2266
2267 for field in &fields.named {
2268 let field_name = field.ident.as_ref().unwrap();
2269 let field_attrs = parse_field_action_attrs(&field.attrs);
2270
2271 if field_attrs.is_attribute {
2272 attrs_str.push(format!("{}=\"...\"", field_name));
2273 } else if field_attrs.is_inner_text {
2274 has_inner_text = true;
2275 }
2276 }
2277
2278 let attrs_part = if !attrs_str.is_empty() {
2279 format!(" {}", attrs_str.join(" "))
2280 } else {
2281 String::new()
2282 };
2283
2284 if has_inner_text {
2285 doc_lines.push(format!(
2286 "- `<{}{}>...</{}>`: {}",
2287 tag, attrs_part, tag, variant_docs
2288 ));
2289 } else if !attrs_str.is_empty() {
2290 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2291 } else {
2292 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2293 }
2294
2295 for field in &fields.named {
2297 let field_name = field.ident.as_ref().unwrap();
2298 let field_attrs = parse_field_action_attrs(&field.attrs);
2299 let field_docs = extract_doc_comments(&field.attrs);
2300
2301 if field_attrs.is_attribute {
2302 doc_lines
2303 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2304 } else if field_attrs.is_inner_text {
2305 doc_lines
2306 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2307 }
2308 }
2309 }
2310 _ => {
2311 }
2313 }
2314 }
2315 }
2316
2317 doc_lines.join("\n")
2318}
2319
2320fn generate_tags_regex(
2322 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2323) -> String {
2324 let mut tag_names = Vec::new();
2325
2326 for variant in variants {
2327 let action_attrs = parse_action_attrs(&variant.attrs);
2328 if let Some(tag) = action_attrs.tag {
2329 tag_names.push(tag);
2330 }
2331 }
2332
2333 if tag_names.is_empty() {
2334 return String::new();
2335 }
2336
2337 let tags_pattern = tag_names.join("|");
2338 format!(
2341 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2342 tags_pattern, tags_pattern, tags_pattern
2343 )
2344}
2345
2346fn generate_multi_tag_output(
2348 input: &DeriveInput,
2349 enum_name: &syn::Ident,
2350 enum_data: &syn::DataEnum,
2351 prompt_template: String,
2352 actions_doc: String,
2353) -> TokenStream {
2354 let found_crate =
2355 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2356 let crate_path = match found_crate {
2357 FoundCrate::Itself => {
2358 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2360 quote!(::#ident)
2361 }
2362 FoundCrate::Name(name) => {
2363 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2364 quote!(::#ident)
2365 }
2366 };
2367
2368 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2370 let user_variables: Vec<String> = placeholders
2371 .iter()
2372 .filter_map(|(name, _)| {
2373 if name != "actions_doc" {
2374 Some(name.clone())
2375 } else {
2376 None
2377 }
2378 })
2379 .collect();
2380
2381 let enum_name_str = enum_name.to_string();
2383 let snake_case_name = to_snake_case(&enum_name_str);
2384 let function_name = syn::Ident::new(
2385 &format!("build_{}_prompt", snake_case_name),
2386 proc_macro2::Span::call_site(),
2387 );
2388
2389 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2391 .iter()
2392 .map(|var| {
2393 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2394 quote! { #ident: &str }
2395 })
2396 .collect();
2397
2398 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2400 .iter()
2401 .map(|var| {
2402 let var_str = var.clone();
2403 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2404 quote! {
2405 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2406 }
2407 })
2408 .collect();
2409
2410 let extractor_name = syn::Ident::new(
2412 &format!("{}Extractor", enum_name),
2413 proc_macro2::Span::call_site(),
2414 );
2415
2416 let filtered_attrs: Vec<_> = input
2418 .attrs
2419 .iter()
2420 .filter(|attr| !attr.path().is_ident("intent"))
2421 .collect();
2422
2423 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2425 .variants
2426 .iter()
2427 .map(|variant| {
2428 let variant_name = &variant.ident;
2429 let variant_attrs: Vec<_> = variant
2430 .attrs
2431 .iter()
2432 .filter(|attr| !attr.path().is_ident("action"))
2433 .collect();
2434 let fields = &variant.fields;
2435
2436 let filtered_fields = match fields {
2438 syn::Fields::Named(named_fields) => {
2439 let filtered: Vec<_> = named_fields
2440 .named
2441 .iter()
2442 .map(|field| {
2443 let field_name = &field.ident;
2444 let field_type = &field.ty;
2445 let field_vis = &field.vis;
2446 let filtered_attrs: Vec<_> = field
2447 .attrs
2448 .iter()
2449 .filter(|attr| !attr.path().is_ident("action"))
2450 .collect();
2451 quote! {
2452 #(#filtered_attrs)*
2453 #field_vis #field_name: #field_type
2454 }
2455 })
2456 .collect();
2457 quote! { { #(#filtered,)* } }
2458 }
2459 syn::Fields::Unnamed(unnamed_fields) => {
2460 let types: Vec<_> = unnamed_fields
2461 .unnamed
2462 .iter()
2463 .map(|field| {
2464 let field_type = &field.ty;
2465 quote! { #field_type }
2466 })
2467 .collect();
2468 quote! { (#(#types),*) }
2469 }
2470 syn::Fields::Unit => quote! {},
2471 };
2472
2473 quote! {
2474 #(#variant_attrs)*
2475 #variant_name #filtered_fields
2476 }
2477 })
2478 .collect();
2479
2480 let vis = &input.vis;
2481 let generics = &input.generics;
2482
2483 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2485
2486 let tags_regex = generate_tags_regex(&enum_data.variants);
2488
2489 let expanded = quote! {
2490 #(#filtered_attrs)*
2492 #vis enum #enum_name #generics {
2493 #(#filtered_variants),*
2494 }
2495
2496 pub fn #function_name(#(#function_params),*) -> String {
2498 let mut env = minijinja::Environment::new();
2499 env.add_template("prompt", #prompt_template)
2500 .expect("Failed to parse intent prompt template");
2501
2502 let tmpl = env.get_template("prompt").unwrap();
2503
2504 let mut __template_context = std::collections::HashMap::new();
2505
2506 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2508
2509 #(#context_insertions)*
2511
2512 tmpl.render(&__template_context)
2513 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2514 }
2515
2516 pub struct #extractor_name;
2518
2519 impl #extractor_name {
2520 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2521 use ::quick_xml::events::Event;
2522 use ::quick_xml::Reader;
2523
2524 let mut actions = Vec::new();
2525 let mut reader = Reader::from_str(text);
2526 reader.config_mut().trim_text(true);
2527
2528 let mut buf = Vec::new();
2529
2530 loop {
2531 match reader.read_event_into(&mut buf) {
2532 Ok(Event::Start(e)) => {
2533 let owned_e = e.into_owned();
2534 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2535 let is_empty = false;
2536
2537 #parsing_arms
2538 }
2539 Ok(Event::Empty(e)) => {
2540 let owned_e = e.into_owned();
2541 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2542 let is_empty = true;
2543
2544 #parsing_arms
2545 }
2546 Ok(Event::Eof) => break,
2547 Err(_) => {
2548 break;
2550 }
2551 _ => {}
2552 }
2553 buf.clear();
2554 }
2555
2556 actions.into_iter().next()
2557 }
2558
2559 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2560 use ::quick_xml::events::Event;
2561 use ::quick_xml::Reader;
2562
2563 let mut actions = Vec::new();
2564 let mut reader = Reader::from_str(text);
2565 reader.config_mut().trim_text(true);
2566
2567 let mut buf = Vec::new();
2568
2569 loop {
2570 match reader.read_event_into(&mut buf) {
2571 Ok(Event::Start(e)) => {
2572 let owned_e = e.into_owned();
2573 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2574 let is_empty = false;
2575
2576 #parsing_arms
2577 }
2578 Ok(Event::Empty(e)) => {
2579 let owned_e = e.into_owned();
2580 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2581 let is_empty = true;
2582
2583 #parsing_arms
2584 }
2585 Ok(Event::Eof) => break,
2586 Err(_) => {
2587 break;
2589 }
2590 _ => {}
2591 }
2592 buf.clear();
2593 }
2594
2595 Ok(actions)
2596 }
2597
2598 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2599 where
2600 F: FnMut(#enum_name) -> String,
2601 {
2602 use ::regex::Regex;
2603
2604 let regex_pattern = #tags_regex;
2605 if regex_pattern.is_empty() {
2606 return text.to_string();
2607 }
2608
2609 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2610 panic!("Failed to compile regex for action tags: {}", e);
2611 });
2612
2613 re.replace_all(text, |caps: &::regex::Captures| {
2614 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2615
2616 if let Some(action) = self.parse_single_action(matched) {
2618 transformer(action)
2619 } else {
2620 matched.to_string()
2622 }
2623 }).to_string()
2624 }
2625
2626 pub fn strip_actions(&self, text: &str) -> String {
2627 self.transform_actions(text, |_| String::new())
2628 }
2629 }
2630 };
2631
2632 TokenStream::from(expanded)
2633}
2634
2635fn generate_parsing_arms(
2637 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2638 enum_name: &syn::Ident,
2639) -> proc_macro2::TokenStream {
2640 let mut arms = Vec::new();
2641
2642 for variant in variants {
2643 let variant_name = &variant.ident;
2644 let action_attrs = parse_action_attrs(&variant.attrs);
2645
2646 if let Some(tag) = action_attrs.tag {
2647 match &variant.fields {
2648 syn::Fields::Unit => {
2649 arms.push(quote! {
2651 if &tag_name == #tag {
2652 actions.push(#enum_name::#variant_name);
2653 }
2654 });
2655 }
2656 syn::Fields::Unnamed(_fields) => {
2657 arms.push(quote! {
2659 if &tag_name == #tag && !is_empty {
2660 match reader.read_text(owned_e.name()) {
2662 Ok(text) => {
2663 actions.push(#enum_name::#variant_name(text.to_string()));
2664 }
2665 Err(_) => {
2666 actions.push(#enum_name::#variant_name(String::new()));
2668 }
2669 }
2670 }
2671 });
2672 }
2673 syn::Fields::Named(fields) => {
2674 let mut field_names = Vec::new();
2676 let mut has_inner_text_field = None;
2677
2678 for field in &fields.named {
2679 let field_name = field.ident.as_ref().unwrap();
2680 let field_attrs = parse_field_action_attrs(&field.attrs);
2681
2682 if field_attrs.is_attribute {
2683 field_names.push(field_name.clone());
2684 } else if field_attrs.is_inner_text {
2685 has_inner_text_field = Some(field_name.clone());
2686 }
2687 }
2688
2689 if let Some(inner_text_field) = has_inner_text_field {
2690 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2693 quote! {
2694 let mut #field_name = String::new();
2695 for attr in owned_e.attributes() {
2696 if let Ok(attr) = attr {
2697 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2698 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2699 break;
2700 }
2701 }
2702 }
2703 }
2704 }).collect();
2705
2706 arms.push(quote! {
2707 if &tag_name == #tag {
2708 #(#attr_extractions)*
2709
2710 if is_empty {
2712 let #inner_text_field = String::new();
2713 actions.push(#enum_name::#variant_name {
2714 #(#field_names,)*
2715 #inner_text_field,
2716 });
2717 } else {
2718 match reader.read_text(owned_e.name()) {
2720 Ok(text) => {
2721 let #inner_text_field = text.to_string();
2722 actions.push(#enum_name::#variant_name {
2723 #(#field_names,)*
2724 #inner_text_field,
2725 });
2726 }
2727 Err(_) => {
2728 let #inner_text_field = String::new();
2730 actions.push(#enum_name::#variant_name {
2731 #(#field_names,)*
2732 #inner_text_field,
2733 });
2734 }
2735 }
2736 }
2737 }
2738 });
2739 } else {
2740 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2742 quote! {
2743 let mut #field_name = String::new();
2744 for attr in owned_e.attributes() {
2745 if let Ok(attr) = attr {
2746 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2747 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2748 break;
2749 }
2750 }
2751 }
2752 }
2753 }).collect();
2754
2755 arms.push(quote! {
2756 if &tag_name == #tag {
2757 #(#attr_extractions)*
2758 actions.push(#enum_name::#variant_name {
2759 #(#field_names),*
2760 });
2761 }
2762 });
2763 }
2764 }
2765 }
2766 }
2767 }
2768
2769 quote! {
2770 #(#arms)*
2771 }
2772}
2773
2774#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2776pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2777 let input = parse_macro_input!(input as DeriveInput);
2778
2779 let found_crate =
2780 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2781 let crate_path = match found_crate {
2782 FoundCrate::Itself => {
2783 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2785 quote!(::#ident)
2786 }
2787 FoundCrate::Name(name) => {
2788 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2789 quote!(::#ident)
2790 }
2791 };
2792
2793 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2795
2796 let struct_name = &input.ident;
2797 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2798
2799 let placeholders = parse_template_placeholders_with_mode(&template);
2801
2802 let mut converted_template = template.clone();
2804 let mut context_fields = Vec::new();
2805
2806 let fields = match &input.data {
2808 Data::Struct(data_struct) => match &data_struct.fields {
2809 syn::Fields::Named(fields) => &fields.named,
2810 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2811 },
2812 _ => panic!("ToPromptFor is only supported for structs"),
2813 };
2814
2815 let has_mode_support = input.attrs.iter().any(|attr| {
2817 if attr.path().is_ident("prompt")
2818 && let Ok(metas) =
2819 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2820 {
2821 for meta in metas {
2822 if let Meta::NameValue(nv) = meta
2823 && nv.path.is_ident("mode")
2824 {
2825 return true;
2826 }
2827 }
2828 }
2829 false
2830 });
2831
2832 for (placeholder_name, mode_opt) in &placeholders {
2834 if placeholder_name == "self" {
2835 if let Some(specific_mode) = mode_opt {
2836 let unique_key = format!("self__{}", specific_mode);
2838
2839 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2841 let replacement = format!("{{{{ {} }}}}", unique_key);
2842 converted_template = converted_template.replace(&pattern, &replacement);
2843
2844 context_fields.push(quote! {
2846 context.insert(
2847 #unique_key.to_string(),
2848 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2849 );
2850 });
2851 } else {
2852 if has_mode_support {
2855 context_fields.push(quote! {
2857 context.insert(
2858 "self".to_string(),
2859 minijinja::Value::from(self.to_prompt_with_mode(mode))
2860 );
2861 });
2862 } else {
2863 context_fields.push(quote! {
2865 context.insert(
2866 "self".to_string(),
2867 minijinja::Value::from(self.to_prompt())
2868 );
2869 });
2870 }
2871 }
2872 } else {
2873 let field_exists = fields.iter().any(|f| {
2876 f.ident
2877 .as_ref()
2878 .is_some_and(|ident| ident == placeholder_name)
2879 });
2880
2881 if field_exists {
2882 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2883
2884 context_fields.push(quote! {
2888 context.insert(
2889 #placeholder_name.to_string(),
2890 minijinja::Value::from_serialize(&self.#field_ident)
2891 );
2892 });
2893 }
2894 }
2896 }
2897
2898 let expanded = quote! {
2899 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2900 where
2901 #target_type: serde::Serialize,
2902 {
2903 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2904 let mut env = minijinja::Environment::new();
2906 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2907 panic!("Failed to parse template: {}", e)
2908 });
2909
2910 let tmpl = env.get_template("prompt").unwrap();
2911
2912 let mut context = std::collections::HashMap::new();
2914 context.insert(
2916 "self".to_string(),
2917 minijinja::Value::from_serialize(self)
2918 );
2919 context.insert(
2921 "target".to_string(),
2922 minijinja::Value::from_serialize(target)
2923 );
2924 #(#context_fields)*
2925
2926 tmpl.render(context).unwrap_or_else(|e| {
2928 format!("Failed to render prompt: {}", e)
2929 })
2930 }
2931 }
2932 };
2933
2934 TokenStream::from(expanded)
2935}
2936
2937struct AgentAttrs {
2943 expertise: Option<String>,
2944 output: Option<syn::Type>,
2945 backend: Option<String>,
2946 model: Option<String>,
2947 inner: Option<String>,
2948 default_inner: Option<String>,
2949 max_retries: Option<u32>,
2950 profile: Option<String>,
2951}
2952
2953impl Parse for AgentAttrs {
2954 fn parse(input: ParseStream) -> syn::Result<Self> {
2955 let mut expertise = None;
2956 let mut output = None;
2957 let mut backend = None;
2958 let mut model = None;
2959 let mut inner = None;
2960 let mut default_inner = None;
2961 let mut max_retries = None;
2962 let mut profile = None;
2963
2964 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
2965
2966 for meta in pairs {
2967 match meta {
2968 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
2969 if let syn::Expr::Lit(syn::ExprLit {
2970 lit: syn::Lit::Str(lit_str),
2971 ..
2972 }) = &nv.value
2973 {
2974 expertise = Some(lit_str.value());
2975 }
2976 }
2977 Meta::NameValue(nv) if nv.path.is_ident("output") => {
2978 if let syn::Expr::Lit(syn::ExprLit {
2979 lit: syn::Lit::Str(lit_str),
2980 ..
2981 }) = &nv.value
2982 {
2983 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
2984 output = Some(ty);
2985 }
2986 }
2987 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
2988 if let syn::Expr::Lit(syn::ExprLit {
2989 lit: syn::Lit::Str(lit_str),
2990 ..
2991 }) = &nv.value
2992 {
2993 backend = Some(lit_str.value());
2994 }
2995 }
2996 Meta::NameValue(nv) if nv.path.is_ident("model") => {
2997 if let syn::Expr::Lit(syn::ExprLit {
2998 lit: syn::Lit::Str(lit_str),
2999 ..
3000 }) = &nv.value
3001 {
3002 model = Some(lit_str.value());
3003 }
3004 }
3005 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3006 if let syn::Expr::Lit(syn::ExprLit {
3007 lit: syn::Lit::Str(lit_str),
3008 ..
3009 }) = &nv.value
3010 {
3011 inner = Some(lit_str.value());
3012 }
3013 }
3014 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3015 if let syn::Expr::Lit(syn::ExprLit {
3016 lit: syn::Lit::Str(lit_str),
3017 ..
3018 }) = &nv.value
3019 {
3020 default_inner = Some(lit_str.value());
3021 }
3022 }
3023 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3024 if let syn::Expr::Lit(syn::ExprLit {
3025 lit: syn::Lit::Int(lit_int),
3026 ..
3027 }) = &nv.value
3028 {
3029 max_retries = Some(lit_int.base10_parse()?);
3030 }
3031 }
3032 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3033 if let syn::Expr::Lit(syn::ExprLit {
3034 lit: syn::Lit::Str(lit_str),
3035 ..
3036 }) = &nv.value
3037 {
3038 profile = Some(lit_str.value());
3039 }
3040 }
3041 _ => {}
3042 }
3043 }
3044
3045 Ok(AgentAttrs {
3046 expertise,
3047 output,
3048 backend,
3049 model,
3050 inner,
3051 default_inner,
3052 max_retries,
3053 profile,
3054 })
3055 }
3056}
3057
3058fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3060 for attr in attrs {
3061 if attr.path().is_ident("agent") {
3062 return attr.parse_args::<AgentAttrs>();
3063 }
3064 }
3065
3066 Ok(AgentAttrs {
3067 expertise: None,
3068 output: None,
3069 backend: None,
3070 model: None,
3071 inner: None,
3072 default_inner: None,
3073 max_retries: None,
3074 profile: None,
3075 })
3076}
3077
3078fn generate_backend_constructors(
3080 struct_name: &syn::Ident,
3081 backend: &str,
3082 _model: Option<&str>,
3083 _profile: Option<&str>,
3084 crate_path: &proc_macro2::TokenStream,
3085) -> proc_macro2::TokenStream {
3086 match backend {
3087 "claude" => {
3088 quote! {
3089 impl #struct_name {
3090 pub fn with_claude() -> Self {
3092 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3093 }
3094
3095 pub fn with_claude_model(model: &str) -> Self {
3097 Self::new(
3098 #crate_path::agent::impls::ClaudeCodeAgent::new()
3099 .with_model_str(model)
3100 )
3101 }
3102 }
3103 }
3104 }
3105 "gemini" => {
3106 quote! {
3107 impl #struct_name {
3108 pub fn with_gemini() -> Self {
3110 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3111 }
3112
3113 pub fn with_gemini_model(model: &str) -> Self {
3115 Self::new(
3116 #crate_path::agent::impls::GeminiAgent::new()
3117 .with_model_str(model)
3118 )
3119 }
3120 }
3121 }
3122 }
3123 _ => quote! {},
3124 }
3125}
3126
3127fn generate_default_impl(
3129 struct_name: &syn::Ident,
3130 backend: &str,
3131 model: Option<&str>,
3132 profile: Option<&str>,
3133 crate_path: &proc_macro2::TokenStream,
3134) -> proc_macro2::TokenStream {
3135 let profile_expr = if let Some(profile_str) = profile {
3137 match profile_str.to_lowercase().as_str() {
3138 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3139 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3140 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3141 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3143 } else {
3144 quote! { #crate_path::agent::ExecutionProfile::default() }
3145 };
3146
3147 let agent_init = match backend {
3148 "gemini" => {
3149 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3150
3151 if let Some(model_str) = model {
3152 builder = quote! { #builder.with_model_str(#model_str) };
3153 }
3154
3155 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3156 builder
3157 }
3158 _ => {
3159 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3161
3162 if let Some(model_str) = model {
3163 builder = quote! { #builder.with_model_str(#model_str) };
3164 }
3165
3166 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3167 builder
3168 }
3169 };
3170
3171 quote! {
3172 impl Default for #struct_name {
3173 fn default() -> Self {
3174 Self::new(#agent_init)
3175 }
3176 }
3177 }
3178}
3179
3180#[proc_macro_derive(Agent, attributes(agent))]
3189pub fn derive_agent(input: TokenStream) -> TokenStream {
3190 let input = parse_macro_input!(input as DeriveInput);
3191 let struct_name = &input.ident;
3192
3193 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3195 Ok(attrs) => attrs,
3196 Err(e) => return e.to_compile_error().into(),
3197 };
3198
3199 let expertise = agent_attrs
3200 .expertise
3201 .unwrap_or_else(|| String::from("general AI assistant"));
3202 let output_type = agent_attrs
3203 .output
3204 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3205 let backend = agent_attrs
3206 .backend
3207 .unwrap_or_else(|| String::from("claude"));
3208 let model = agent_attrs.model;
3209 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3214 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3215 let crate_path = match found_crate {
3216 FoundCrate::Itself => {
3217 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3219 quote!(::#ident)
3220 }
3221 FoundCrate::Name(name) => {
3222 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3223 quote!(::#ident)
3224 }
3225 };
3226
3227 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3228
3229 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3231 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3232
3233 let enhanced_expertise = if is_string_output {
3235 quote! { #expertise }
3237 } else {
3238 let type_name = quote!(#output_type).to_string();
3240 quote! {
3241 {
3242 use std::sync::OnceLock;
3243 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3244
3245 EXPERTISE_CACHE.get_or_init(|| {
3246 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3248
3249 if schema.is_empty() {
3250 format!(
3252 concat!(
3253 #expertise,
3254 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3255 "Do not include any text outside the JSON object."
3256 ),
3257 #type_name
3258 )
3259 } else {
3260 format!(
3262 concat!(
3263 #expertise,
3264 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3265 ),
3266 schema
3267 )
3268 }
3269 }).as_str()
3270 }
3271 }
3272 };
3273
3274 let agent_init = match backend.as_str() {
3276 "gemini" => {
3277 if let Some(model_str) = model {
3278 quote! {
3279 use #crate_path::agent::impls::GeminiAgent;
3280 let agent = GeminiAgent::new().with_model_str(#model_str);
3281 }
3282 } else {
3283 quote! {
3284 use #crate_path::agent::impls::GeminiAgent;
3285 let agent = GeminiAgent::new();
3286 }
3287 }
3288 }
3289 "claude" => {
3290 if let Some(model_str) = model {
3291 quote! {
3292 use #crate_path::agent::impls::ClaudeCodeAgent;
3293 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3294 }
3295 } else {
3296 quote! {
3297 use #crate_path::agent::impls::ClaudeCodeAgent;
3298 let agent = ClaudeCodeAgent::new();
3299 }
3300 }
3301 }
3302 _ => {
3303 if let Some(model_str) = model {
3305 quote! {
3306 use #crate_path::agent::impls::ClaudeCodeAgent;
3307 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3308 }
3309 } else {
3310 quote! {
3311 use #crate_path::agent::impls::ClaudeCodeAgent;
3312 let agent = ClaudeCodeAgent::new();
3313 }
3314 }
3315 }
3316 };
3317
3318 let expanded = quote! {
3319 #[async_trait::async_trait]
3320 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3321 type Output = #output_type;
3322
3323 fn expertise(&self) -> &str {
3324 #enhanced_expertise
3325 }
3326
3327 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3328 #agent_init
3330
3331 let agent_ref = &agent;
3333 #crate_path::agent::retry::retry_execution(
3334 #max_retries,
3335 &intent,
3336 move |payload| {
3337 let payload = payload.clone();
3338 async move {
3339 let response = agent_ref.execute(payload).await?;
3341
3342 let json_str = #crate_path::extract_json(&response)
3344 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3345 message: format!("Failed to extract JSON: {}", e),
3346 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3347 })?;
3348
3349 serde_json::from_str::<Self::Output>(&json_str)
3351 .map_err(|e| {
3352 let reason = if e.is_eof() {
3354 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3355 } else if e.is_syntax() {
3356 #crate_path::agent::error::ParseErrorReason::InvalidJson
3357 } else {
3358 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3359 };
3360
3361 #crate_path::agent::AgentError::ParseError {
3362 message: format!("Failed to parse JSON: {}", e),
3363 reason,
3364 }
3365 })
3366 }
3367 }
3368 ).await
3369 }
3370
3371 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3372 #agent_init
3374 agent.is_available().await
3375 }
3376 }
3377 };
3378
3379 TokenStream::from(expanded)
3380}
3381
3382#[proc_macro_attribute]
3397pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3398 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3400 Ok(attrs) => attrs,
3401 Err(e) => return e.to_compile_error().into(),
3402 };
3403
3404 let input = parse_macro_input!(item as DeriveInput);
3406 let struct_name = &input.ident;
3407 let vis = &input.vis;
3408
3409 let expertise = agent_attrs
3410 .expertise
3411 .unwrap_or_else(|| String::from("general AI assistant"));
3412 let output_type = agent_attrs
3413 .output
3414 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3415 let backend = agent_attrs
3416 .backend
3417 .unwrap_or_else(|| String::from("claude"));
3418 let model = agent_attrs.model;
3419 let profile = agent_attrs.profile;
3420
3421 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3423 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3424
3425 let found_crate =
3427 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3428 let crate_path = match found_crate {
3429 FoundCrate::Itself => {
3430 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3431 quote!(::#ident)
3432 }
3433 FoundCrate::Name(name) => {
3434 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3435 quote!(::#ident)
3436 }
3437 };
3438
3439 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3441 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3442
3443 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3445 let type_path: syn::Type =
3447 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3448 quote! { #type_path }
3449 } else {
3450 match backend.as_str() {
3452 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3453 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3454 }
3455 };
3456
3457 let struct_def = quote! {
3459 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3460 inner: #inner_generic_ident,
3461 }
3462 };
3463
3464 let constructors = quote! {
3466 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3467 pub fn new(inner: #inner_generic_ident) -> Self {
3469 Self { inner }
3470 }
3471 }
3472 };
3473
3474 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3476 let default_impl = quote! {
3478 impl Default for #struct_name {
3479 fn default() -> Self {
3480 Self {
3481 inner: <#default_agent_type as Default>::default(),
3482 }
3483 }
3484 }
3485 };
3486 (quote! {}, default_impl)
3487 } else {
3488 let backend_constructors = generate_backend_constructors(
3490 struct_name,
3491 &backend,
3492 model.as_deref(),
3493 profile.as_deref(),
3494 &crate_path,
3495 );
3496 let default_impl = generate_default_impl(
3497 struct_name,
3498 &backend,
3499 model.as_deref(),
3500 profile.as_deref(),
3501 &crate_path,
3502 );
3503 (backend_constructors, default_impl)
3504 };
3505
3506 let enhanced_expertise = if is_string_output {
3508 quote! { #expertise }
3510 } else {
3511 let type_name = quote!(#output_type).to_string();
3513 quote! {
3514 {
3515 use std::sync::OnceLock;
3516 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3517
3518 EXPERTISE_CACHE.get_or_init(|| {
3519 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3521
3522 if schema.is_empty() {
3523 format!(
3525 concat!(
3526 #expertise,
3527 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3528 "Do not include any text outside the JSON object."
3529 ),
3530 #type_name
3531 )
3532 } else {
3533 format!(
3535 concat!(
3536 #expertise,
3537 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3538 ),
3539 schema
3540 )
3541 }
3542 }).as_str()
3543 }
3544 }
3545 };
3546
3547 let agent_impl = quote! {
3549 #[async_trait::async_trait]
3550 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3551 where
3552 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3553 {
3554 type Output = #output_type;
3555
3556 fn expertise(&self) -> &str {
3557 #enhanced_expertise
3558 }
3559
3560 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3561 let enhanced_payload = intent.prepend_text(self.expertise());
3563
3564 let response = self.inner.execute(enhanced_payload).await?;
3566
3567 let json_str = #crate_path::extract_json(&response)
3569 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3570 message: e.to_string(),
3571 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3572 })?;
3573
3574 serde_json::from_str(&json_str).map_err(|e| {
3576 let reason = if e.is_eof() {
3577 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3578 } else if e.is_syntax() {
3579 #crate_path::agent::error::ParseErrorReason::InvalidJson
3580 } else {
3581 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3582 };
3583 #crate_path::agent::AgentError::ParseError {
3584 message: e.to_string(),
3585 reason,
3586 }
3587 })
3588 }
3589
3590 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3591 self.inner.is_available().await
3592 }
3593 }
3594 };
3595
3596 let expanded = quote! {
3597 #struct_def
3598 #constructors
3599 #backend_constructors
3600 #default_impl
3601 #agent_impl
3602 };
3603
3604 TokenStream::from(expanded)
3605}
3606
3607#[proc_macro_derive(TypeMarker)]
3629pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3630 let input = parse_macro_input!(input as DeriveInput);
3631 let struct_name = &input.ident;
3632 let type_name_str = struct_name.to_string();
3633
3634 let found_crate =
3636 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3637 let crate_path = match found_crate {
3638 FoundCrate::Itself => {
3639 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3640 quote!(::#ident)
3641 }
3642 FoundCrate::Name(name) => {
3643 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3644 quote!(::#ident)
3645 }
3646 };
3647
3648 let expanded = quote! {
3649 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3650 const TYPE_NAME: &'static str = #type_name_str;
3651 }
3652 };
3653
3654 TokenStream::from(expanded)
3655}
3656
3657#[proc_macro_attribute]
3693pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
3694 let input = parse_macro_input!(item as syn::DeriveInput);
3695 let struct_name = &input.ident;
3696 let vis = &input.vis;
3697 let type_name_str = struct_name.to_string();
3698
3699 let default_fn_name = syn::Ident::new(
3701 &format!("default_{}_type", to_snake_case(&type_name_str)),
3702 struct_name.span(),
3703 );
3704
3705 let found_crate =
3707 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3708 let crate_path = match found_crate {
3709 FoundCrate::Itself => {
3710 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3711 quote!(::#ident)
3712 }
3713 FoundCrate::Name(name) => {
3714 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3715 quote!(::#ident)
3716 }
3717 };
3718
3719 let fields = match &input.data {
3721 syn::Data::Struct(data_struct) => match &data_struct.fields {
3722 syn::Fields::Named(fields) => &fields.named,
3723 _ => {
3724 return syn::Error::new_spanned(
3725 struct_name,
3726 "type_marker only works with structs with named fields",
3727 )
3728 .to_compile_error()
3729 .into();
3730 }
3731 },
3732 _ => {
3733 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
3734 .to_compile_error()
3735 .into();
3736 }
3737 };
3738
3739 let mut new_fields = vec![];
3741
3742 let default_fn_name_str = default_fn_name.to_string();
3744 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
3745
3746 new_fields.push(quote! {
3751 #[serde(default = #default_fn_name_lit)]
3752 __type: String
3753 });
3754
3755 for field in fields {
3757 new_fields.push(quote! { #field });
3758 }
3759
3760 let attrs = &input.attrs;
3762 let generics = &input.generics;
3763
3764 let expanded = quote! {
3765 fn #default_fn_name() -> String {
3767 #type_name_str.to_string()
3768 }
3769
3770 #(#attrs)*
3772 #vis struct #struct_name #generics {
3773 #(#new_fields),*
3774 }
3775
3776 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3778 const TYPE_NAME: &'static str = #type_name_str;
3779 }
3780 };
3781
3782 TokenStream::from(expanded)
3783}