1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
169
170 if is_vec {
171 let comment = if !field_docs.is_empty() {
174 format!(" // {}", field_docs)
175 } else {
176 String::new()
177 };
178
179 field_schema_parts.push(quote! {
180 {
181 let type_name = stringify!(#inner_type);
182 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
183 }
184 });
185
186 if let Some(inner) = inner_type
188 && !is_primitive_type(inner)
189 {
190 nested_type_collectors.push(quote! {
191 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
192 });
193 }
194 } else {
195 let field_type = &field.ty;
197 let is_primitive = is_primitive_type(field_type);
198
199 if !is_primitive {
200 let comment = if !field_docs.is_empty() {
203 format!(" // {}", field_docs)
204 } else {
205 String::new()
206 };
207
208 field_schema_parts.push(quote! {
209 {
210 let type_name = stringify!(#field_type);
211 format!(" {}: {};{}", #field_name_str, type_name, #comment)
212 }
213 });
214
215 nested_type_collectors.push(quote! {
217 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
218 });
219 } else {
220 let type_str = format_type_for_schema(&field.ty);
223 let comment = if !field_docs.is_empty() {
224 format!(" // {}", field_docs)
225 } else {
226 String::new()
227 };
228
229 field_schema_parts.push(quote! {
230 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
231 });
232 }
233 }
234 }
235
236 let mut header_lines = Vec::new();
251
252 if !struct_docs.is_empty() {
254 header_lines.push("/**".to_string());
255 header_lines.push(format!(" * {}", struct_docs));
256 header_lines.push(" */".to_string());
257 }
258
259 header_lines.push(format!("type {} = {{", struct_name));
261
262 quote! {
263 {
264 let mut all_lines: Vec<String> = Vec::new();
265
266 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
268 let mut seen_types = std::collections::HashSet::<String>::new();
269
270 for schema in nested_schemas {
271 if !schema.is_empty() {
272 if seen_types.insert(schema.clone()) {
274 all_lines.push(schema);
275 all_lines.push(String::new()); }
277 }
278 }
279
280 let mut lines: Vec<String> = Vec::new();
282 #(lines.push(#header_lines.to_string());)*
283 #(lines.push(#field_schema_parts);)*
284 lines.push("}".to_string());
285 all_lines.push(lines.join("\n"));
286
287 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
288 }
289 }
290}
291
292fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
294 if let syn::Type::Path(type_path) = ty
295 && let Some(last_segment) = type_path.path.segments.last()
296 && last_segment.ident == "Vec"
297 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
298 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
299 {
300 return (true, Some(inner_type));
301 }
302 (false, None)
303}
304
305fn is_primitive_type(ty: &syn::Type) -> bool {
307 if let syn::Type::Path(type_path) = ty
308 && let Some(last_segment) = type_path.path.segments.last()
309 {
310 let type_name = last_segment.ident.to_string();
311 matches!(
312 type_name.as_str(),
313 "String"
314 | "str"
315 | "i8"
316 | "i16"
317 | "i32"
318 | "i64"
319 | "i128"
320 | "isize"
321 | "u8"
322 | "u16"
323 | "u32"
324 | "u64"
325 | "u128"
326 | "usize"
327 | "f32"
328 | "f64"
329 | "bool"
330 | "Vec"
331 | "Option"
332 | "HashMap"
333 | "BTreeMap"
334 | "HashSet"
335 | "BTreeSet"
336 )
337 } else {
338 true
340 }
341}
342
343fn format_type_for_schema(ty: &syn::Type) -> String {
345 match ty {
347 syn::Type::Path(type_path) => {
348 let path = &type_path.path;
349 if let Some(last_segment) = path.segments.last() {
350 let type_name = last_segment.ident.to_string();
351
352 if type_name == "Option"
354 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
355 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
356 {
357 return format!("{} | null", format_type_for_schema(inner_type));
358 }
359
360 match type_name.as_str() {
362 "String" | "str" => "string".to_string(),
363 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
364 | "u64" | "u128" | "usize" => "number".to_string(),
365 "f32" | "f64" => "number".to_string(),
366 "bool" => "boolean".to_string(),
367 "Vec" => {
368 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
369 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
370 {
371 return format!("{}[]", format_type_for_schema(inner_type));
372 }
373 "array".to_string()
374 }
375 _ => type_name.to_lowercase(),
376 }
377 } else {
378 "unknown".to_string()
379 }
380 }
381 _ => "unknown".to_string(),
382 }
383}
384
385enum PromptAttribute {
387 Skip,
388 Description(String),
389 None,
390}
391
392fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
394 for attr in attrs {
395 if attr.path().is_ident("prompt") {
396 if let Ok(meta_list) = attr.meta.require_list() {
398 let tokens = &meta_list.tokens;
399 let tokens_str = tokens.to_string();
400 if tokens_str == "skip" {
401 return PromptAttribute::Skip;
402 }
403 }
404
405 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
407 return PromptAttribute::Description(lit_str.value());
408 }
409 }
410 }
411 PromptAttribute::None
412}
413
/// Case-conversion rules accepted by serde's `#[serde(rename_all = "...")]` attribute.
///
/// Used to predict the serialized names of enum variants, so the generated schema lists the
/// same strings serde actually produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leave names unchanged. Never produced by `from_str`, hence the `dead_code` allowance.
    #[allow(dead_code)]
    None,
    /// `"lowercase"`
    LowerCase,
    /// `"UPPERCASE"`
    UpperCase,
    /// `"PascalCase"`
    PascalCase,
    /// `"camelCase"`
    CamelCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses a serde `rename_all` value such as `"snake_case"`.
    ///
    /// Returns `None` for unrecognised rules.
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies the rule to an enum variant name (normally written in PascalCase).
    ///
    /// Mirrors serde_derive's variant renaming so the result matches serde's output exactly:
    /// * case changes are ASCII-only (non-ASCII letters are kept as written);
    /// * snake forms insert `_` before every uppercase character except the first;
    /// * kebab forms are the snake forms with every `_` replaced by `-`, including underscores
    ///   already present in the name.
    fn apply(&self, name: &str) -> String {
        /// `HttpServer` -> `http_server`; the base for all snake and kebab variants.
        fn snake_case(name: &str) -> String {
            let mut snake = String::with_capacity(name.len() + name.len() / 2);
            for (index, ch) in name.char_indices() {
                if index > 0 && ch.is_uppercase() {
                    snake.push('_');
                }
                snake.push(ch.to_ascii_lowercase());
            }
            snake
        }

        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_ascii_lowercase(),
            Self::UpperCase => name.to_ascii_uppercase(),
            Self::CamelCase => {
                let mut chars = name.chars();
                match chars.next() {
                    Some(first) => {
                        let mut camel = String::with_capacity(name.len());
                        camel.push(first.to_ascii_lowercase());
                        camel.push_str(chars.as_str());
                        camel
                    }
                    None => String::new(),
                }
            }
            Self::SnakeCase => snake_case(name),
            Self::ScreamingSnakeCase => snake_case(name).to_ascii_uppercase(),
            Self::KebabCase => snake_case(name).replace('_', "-"),
            Self::ScreamingKebabCase => snake_case(name).to_ascii_uppercase().replace('_', "-"),
        }
    }
}
507
508fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
510 for attr in attrs {
511 if attr.path().is_ident("serde")
512 && let Ok(meta_list) = attr.meta.require_list()
513 {
514 if let Ok(metas) =
516 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
517 {
518 for meta in metas {
519 if let Meta::NameValue(nv) = meta
520 && nv.path.is_ident("rename_all")
521 && let syn::Expr::Lit(syn::ExprLit {
522 lit: syn::Lit::Str(lit_str),
523 ..
524 }) = nv.value
525 {
526 return RenameRule::from_str(&lit_str.value());
527 }
528 }
529 }
530 }
531 }
532 None
533}
534
535#[derive(Debug, Default)]
537struct FieldPromptAttrs {
538 skip: bool,
539 rename: Option<String>,
540 format_with: Option<String>,
541 image: bool,
542 example: Option<String>,
543}
544
545fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
547 let mut result = FieldPromptAttrs::default();
548
549 for attr in attrs {
550 if attr.path().is_ident("prompt") {
551 if let Ok(meta_list) = attr.meta.require_list() {
553 if let Ok(metas) =
555 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
556 {
557 for meta in metas {
558 match meta {
559 Meta::Path(path) if path.is_ident("skip") => {
560 result.skip = true;
561 }
562 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
563 if let syn::Expr::Lit(syn::ExprLit {
564 lit: syn::Lit::Str(lit_str),
565 ..
566 }) = nv.value
567 {
568 result.rename = Some(lit_str.value());
569 }
570 }
571 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
572 if let syn::Expr::Lit(syn::ExprLit {
573 lit: syn::Lit::Str(lit_str),
574 ..
575 }) = nv.value
576 {
577 result.format_with = Some(lit_str.value());
578 }
579 }
580 Meta::Path(path) if path.is_ident("image") => {
581 result.image = true;
582 }
583 Meta::NameValue(nv) if nv.path.is_ident("example") => {
584 if let syn::Expr::Lit(syn::ExprLit {
585 lit: syn::Lit::Str(lit_str),
586 ..
587 }) = nv.value
588 {
589 result.example = Some(lit_str.value());
590 }
591 }
592 _ => {}
593 }
594 }
595 } else if meta_list.tokens.to_string() == "skip" {
596 result.skip = true;
598 } else if meta_list.tokens.to_string() == "image" {
599 result.image = true;
601 }
602 }
603 }
604 }
605
606 result
607}
608
609#[proc_macro_derive(ToPrompt, attributes(prompt))]
652pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
653 let input = parse_macro_input!(input as DeriveInput);
654
655 let found_crate =
656 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
657 let crate_path = match found_crate {
658 FoundCrate::Itself => {
659 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
661 quote!(::#ident)
662 }
663 FoundCrate::Name(name) => {
664 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
665 quote!(::#ident)
666 }
667 };
668
669 match &input.data {
671 Data::Enum(data_enum) => {
672 let enum_name = &input.ident;
674 let enum_docs = extract_doc_comments(&input.attrs);
675
676 let rename_rule = parse_serde_rename_all(&input.attrs);
678
679 let mut variant_lines = Vec::new();
692 let mut first_variant_name = None;
693
694 for variant in &data_enum.variants {
695 let variant_name = &variant.ident;
696 let variant_name_str = variant_name.to_string();
697
698 let variant_value = if let Some(rule) = rename_rule {
700 rule.apply(&variant_name_str)
701 } else {
702 variant_name_str.clone()
703 };
704
705 match parse_prompt_attribute(&variant.attrs) {
706 PromptAttribute::Skip => continue,
707 PromptAttribute::Description(desc) => {
708 variant_lines.push(format!(" | \"{}\" // {}", variant_value, desc));
709 if first_variant_name.is_none() {
710 first_variant_name = Some(variant_value.clone());
711 }
712 }
713 PromptAttribute::None => {
714 let docs = extract_doc_comments(&variant.attrs);
715 if !docs.is_empty() {
716 variant_lines.push(format!(" | \"{}\" // {}", variant_value, docs));
717 } else {
718 variant_lines.push(format!(" | \"{}\"", variant_value));
719 }
720 if first_variant_name.is_none() {
721 first_variant_name = Some(variant_value.clone());
722 }
723 }
724 }
725 }
726
727 let mut lines = Vec::new();
729
730 if !enum_docs.is_empty() {
732 lines.push("/**".to_string());
733 lines.push(format!(" * {}", enum_docs));
734 lines.push(" */".to_string());
735 }
736
737 lines.push(format!("type {} =", enum_name));
739
740 for line in &variant_lines {
742 lines.push(line.clone());
743 }
744
745 if let Some(last) = lines.last_mut()
747 && !last.ends_with(';')
748 {
749 last.push(';');
750 }
751
752 if let Some(first_name) = first_variant_name {
754 lines.push("".to_string()); lines.push(format!("Example value: \"{}\"", first_name));
756 }
757
758 let prompt_string = lines.join("\n");
759 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
760
761 let mut match_arms = Vec::new();
763 for variant in &data_enum.variants {
764 let variant_name = &variant.ident;
765
766 match parse_prompt_attribute(&variant.attrs) {
768 PromptAttribute::Skip => {
769 match_arms.push(quote! {
771 Self::#variant_name => stringify!(#variant_name).to_string()
772 });
773 }
774 PromptAttribute::Description(desc) => {
775 match_arms.push(quote! {
777 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
778 });
779 }
780 PromptAttribute::None => {
781 let variant_docs = extract_doc_comments(&variant.attrs);
783 if !variant_docs.is_empty() {
784 match_arms.push(quote! {
785 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
786 });
787 } else {
788 match_arms.push(quote! {
789 Self::#variant_name => stringify!(#variant_name).to_string()
790 });
791 }
792 }
793 }
794 }
795
796 let to_prompt_impl = if match_arms.is_empty() {
797 quote! {
799 fn to_prompt(&self) -> String {
800 match *self {}
801 }
802 }
803 } else {
804 quote! {
805 fn to_prompt(&self) -> String {
806 match self {
807 #(#match_arms),*
808 }
809 }
810 }
811 };
812
813 let expanded = quote! {
814 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
815 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
816 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
817 }
818
819 #to_prompt_impl
820
821 fn prompt_schema() -> String {
822 #prompt_string.to_string()
823 }
824 }
825 };
826
827 TokenStream::from(expanded)
828 }
829 Data::Struct(data_struct) => {
830 let mut template_attr = None;
832 let mut template_file_attr = None;
833 let mut mode_attr = None;
834 let mut validate_attr = false;
835 let mut type_marker_attr = false;
836
837 for attr in &input.attrs {
838 if attr.path().is_ident("prompt") {
839 if let Ok(metas) =
841 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
842 {
843 for meta in metas {
844 match meta {
845 Meta::NameValue(nv) if nv.path.is_ident("template") => {
846 if let syn::Expr::Lit(expr_lit) = nv.value
847 && let syn::Lit::Str(lit_str) = expr_lit.lit
848 {
849 template_attr = Some(lit_str.value());
850 }
851 }
852 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
853 if let syn::Expr::Lit(expr_lit) = nv.value
854 && let syn::Lit::Str(lit_str) = expr_lit.lit
855 {
856 template_file_attr = Some(lit_str.value());
857 }
858 }
859 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
860 if let syn::Expr::Lit(expr_lit) = nv.value
861 && let syn::Lit::Str(lit_str) = expr_lit.lit
862 {
863 mode_attr = Some(lit_str.value());
864 }
865 }
866 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
867 if let syn::Expr::Lit(expr_lit) = nv.value
868 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
869 {
870 validate_attr = lit_bool.value();
871 }
872 }
873 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
874 if let syn::Expr::Lit(expr_lit) = nv.value
875 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
876 {
877 type_marker_attr = lit_bool.value();
878 }
879 }
880 Meta::Path(path) if path.is_ident("type_marker") => {
881 type_marker_attr = true;
883 }
884 _ => {}
885 }
886 }
887 }
888 }
889 }
890
891 if template_attr.is_some() && template_file_attr.is_some() {
893 return syn::Error::new(
894 input.ident.span(),
895 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
896 ).to_compile_error().into();
897 }
898
899 let template_str = if let Some(file_path) = template_file_attr {
901 let mut full_path = None;
905
906 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
908 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
910
911 if !is_trybuild {
912 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
914 if candidate.exists() {
915 full_path = Some(candidate);
916 }
917 } else {
918 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
924 let workspace_root = &manifest_dir[..target_pos];
925 let original_macros_dir = std::path::Path::new(workspace_root)
927 .join("crates")
928 .join("llm-toolkit-macros");
929
930 let candidate = original_macros_dir.join(&file_path);
931 if candidate.exists() {
932 full_path = Some(candidate);
933 }
934 }
935 }
936 }
937
938 if full_path.is_none() {
940 let candidate = std::path::Path::new(&file_path).to_path_buf();
941 if candidate.exists() {
942 full_path = Some(candidate);
943 }
944 }
945
946 if full_path.is_none()
949 && let Ok(current_dir) = std::env::current_dir()
950 {
951 let mut search_dir = current_dir.as_path();
952 for _ in 0..10 {
954 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
956 if macros_dir.exists() {
957 let candidate = macros_dir.join(&file_path);
958 if candidate.exists() {
959 full_path = Some(candidate);
960 break;
961 }
962 }
963 let candidate = search_dir.join(&file_path);
965 if candidate.exists() {
966 full_path = Some(candidate);
967 break;
968 }
969 if let Some(parent) = search_dir.parent() {
970 search_dir = parent;
971 } else {
972 break;
973 }
974 }
975 }
976
977 if full_path.is_none() {
979 let mut error_msg = format!(
981 "Template file '{}' not found at compile time.\n\nSearched in:",
982 file_path
983 );
984
985 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
986 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
987 error_msg.push_str(&format!("\n - {}", candidate.display()));
988 }
989
990 if let Ok(current_dir) = std::env::current_dir() {
991 let candidate = current_dir.join(&file_path);
992 error_msg.push_str(&format!("\n - {}", candidate.display()));
993 }
994
995 error_msg.push_str("\n\nPlease ensure:");
996 error_msg.push_str("\n 1. The template file exists");
997 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
998 error_msg.push_str("\n 3. There are no typos in the path");
999
1000 return syn::Error::new(input.ident.span(), error_msg)
1001 .to_compile_error()
1002 .into();
1003 }
1004
1005 let final_path = full_path.unwrap();
1006
1007 match std::fs::read_to_string(&final_path) {
1009 Ok(content) => Some(content),
1010 Err(e) => {
1011 return syn::Error::new(
1012 input.ident.span(),
1013 format!(
1014 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1015 file_path,
1016 e,
1017 final_path.display()
1018 ),
1019 )
1020 .to_compile_error()
1021 .into();
1022 }
1023 }
1024 } else {
1025 template_attr
1026 };
1027
1028 if validate_attr && let Some(template) = &template_str {
1030 let mut env = minijinja::Environment::new();
1032 if let Err(e) = env.add_template("validation", template) {
1033 let warning_msg =
1035 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1036 let warning_ident = syn::Ident::new(
1037 "TEMPLATE_VALIDATION_WARNING",
1038 proc_macro2::Span::call_site(),
1039 );
1040 let _warning_tokens = quote! {
1041 #[deprecated(note = #warning_msg)]
1042 const #warning_ident: () = ();
1043 let _ = #warning_ident;
1044 };
1045 eprintln!("cargo:warning={}", warning_msg);
1047 }
1048
1049 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1051 &fields.named
1052 } else {
1053 panic!("Template validation is only supported for structs with named fields.");
1054 };
1055
1056 let field_names: std::collections::HashSet<String> = fields
1057 .iter()
1058 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1059 .collect();
1060
1061 let placeholders = parse_template_placeholders_with_mode(template);
1063
1064 for (placeholder_name, _mode) in &placeholders {
1065 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1066 let warning_msg = format!(
1067 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1068 placeholder_name
1069 );
1070 eprintln!("cargo:warning={}", warning_msg);
1071 }
1072 }
1073 }
1074
1075 let name = input.ident;
1076 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1077
1078 let struct_docs = extract_doc_comments(&input.attrs);
1080
1081 let is_mode_based =
1083 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1084
1085 let expanded = if is_mode_based || mode_attr.is_some() {
1086 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1088 &fields.named
1089 } else {
1090 panic!(
1091 "Mode-based prompt generation is only supported for structs with named fields."
1092 );
1093 };
1094
1095 let struct_name_str = name.to_string();
1096
1097 let has_default = input.attrs.iter().any(|attr| {
1099 if attr.path().is_ident("derive")
1100 && let Ok(meta_list) = attr.meta.require_list()
1101 {
1102 let tokens_str = meta_list.tokens.to_string();
1103 tokens_str.contains("Default")
1104 } else {
1105 false
1106 }
1107 });
1108
1109 let schema_parts = generate_schema_only_parts(
1120 &struct_name_str,
1121 &struct_docs,
1122 fields,
1123 &crate_path,
1124 type_marker_attr,
1125 );
1126
1127 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1129
1130 quote! {
1131 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1132 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1133 match mode {
1134 "schema_only" => #schema_parts,
1135 "example_only" => #example_parts,
1136 "full" | _ => {
1137 let mut parts = Vec::new();
1139
1140 let schema_parts = #schema_parts;
1142 parts.extend(schema_parts);
1143
1144 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1146 parts.push(#crate_path::prompt::PromptPart::Text(
1147 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1148 ));
1149
1150 let example_parts = #example_parts;
1152 parts.extend(example_parts);
1153
1154 parts
1155 }
1156 }
1157 }
1158
1159 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1160 self.to_prompt_parts_with_mode("full")
1161 }
1162
1163 fn to_prompt(&self) -> String {
1164 self.to_prompt_parts()
1165 .into_iter()
1166 .filter_map(|part| match part {
1167 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1168 _ => None,
1169 })
1170 .collect::<Vec<_>>()
1171 .join("\n")
1172 }
1173
1174 fn prompt_schema() -> String {
1175 use std::sync::OnceLock;
1176 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1177
1178 SCHEMA_CACHE.get_or_init(|| {
1179 let schema_parts = #schema_parts;
1180 schema_parts
1181 .into_iter()
1182 .filter_map(|part| match part {
1183 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1184 _ => None,
1185 })
1186 .collect::<Vec<_>>()
1187 .join("\n")
1188 }).clone()
1189 }
1190 }
1191 }
1192 } else if let Some(template) = template_str {
1193 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1196 &fields.named
1197 } else {
1198 panic!(
1199 "Template prompt generation is only supported for structs with named fields."
1200 );
1201 };
1202
1203 let placeholders = parse_template_placeholders_with_mode(&template);
1205 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1207 mode.is_some()
1208 && fields
1209 .iter()
1210 .any(|f| f.ident.as_ref().unwrap() == field_name)
1211 });
1212
1213 let mut image_field_parts = Vec::new();
1214 for f in fields.iter() {
1215 let field_name = f.ident.as_ref().unwrap();
1216 let attrs = parse_field_prompt_attrs(&f.attrs);
1217
1218 if attrs.image {
1219 image_field_parts.push(quote! {
1221 parts.extend(self.#field_name.to_prompt_parts());
1222 });
1223 }
1224 }
1225
1226 if has_mode_syntax {
1228 let mut context_fields = Vec::new();
1230 let mut modified_template = template.clone();
1231
1232 for (field_name, mode_opt) in &placeholders {
1234 if let Some(mode) = mode_opt {
1235 let unique_key = format!("{}__{}", field_name, mode);
1237
1238 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1240 let replacement = format!("{{{{ {} }}}}", unique_key);
1241 modified_template = modified_template.replace(&pattern, &replacement);
1242
1243 let field_ident =
1245 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1246
1247 context_fields.push(quote! {
1249 context.insert(
1250 #unique_key.to_string(),
1251 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1252 );
1253 });
1254 }
1255 }
1256
1257 for field in fields.iter() {
1259 let field_name = field.ident.as_ref().unwrap();
1260 let field_name_str = field_name.to_string();
1261
1262 let has_mode_entry = placeholders
1264 .iter()
1265 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1266
1267 if !has_mode_entry {
1268 let is_primitive = match &field.ty {
1271 syn::Type::Path(type_path) => {
1272 if let Some(segment) = type_path.path.segments.last() {
1273 let type_name = segment.ident.to_string();
1274 matches!(
1275 type_name.as_str(),
1276 "String"
1277 | "str"
1278 | "i8"
1279 | "i16"
1280 | "i32"
1281 | "i64"
1282 | "i128"
1283 | "isize"
1284 | "u8"
1285 | "u16"
1286 | "u32"
1287 | "u64"
1288 | "u128"
1289 | "usize"
1290 | "f32"
1291 | "f64"
1292 | "bool"
1293 | "char"
1294 )
1295 } else {
1296 false
1297 }
1298 }
1299 _ => false,
1300 };
1301
1302 if is_primitive {
1303 context_fields.push(quote! {
1304 context.insert(
1305 #field_name_str.to_string(),
1306 minijinja::Value::from_serialize(&self.#field_name)
1307 );
1308 });
1309 } else {
1310 context_fields.push(quote! {
1312 context.insert(
1313 #field_name_str.to_string(),
1314 minijinja::Value::from(self.#field_name.to_prompt())
1315 );
1316 });
1317 }
1318 }
1319 }
1320
1321 quote! {
1322 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1323 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1324 let mut parts = Vec::new();
1325
1326 #(#image_field_parts)*
1328
1329 let text = {
1331 let mut env = minijinja::Environment::new();
1332 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1333 panic!("Failed to parse template: {}", e)
1334 });
1335
1336 let tmpl = env.get_template("prompt").unwrap();
1337
1338 let mut context = std::collections::HashMap::new();
1339 #(#context_fields)*
1340
1341 tmpl.render(context).unwrap_or_else(|e| {
1342 format!("Failed to render prompt: {}", e)
1343 })
1344 };
1345
1346 if !text.is_empty() {
1347 parts.push(#crate_path::prompt::PromptPart::Text(text));
1348 }
1349
1350 parts
1351 }
1352
1353 fn to_prompt(&self) -> String {
1354 let mut env = minijinja::Environment::new();
1356 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1357 panic!("Failed to parse template: {}", e)
1358 });
1359
1360 let tmpl = env.get_template("prompt").unwrap();
1361
1362 let mut context = std::collections::HashMap::new();
1363 #(#context_fields)*
1364
1365 tmpl.render(context).unwrap_or_else(|e| {
1366 format!("Failed to render prompt: {}", e)
1367 })
1368 }
1369
1370 fn prompt_schema() -> String {
1371 String::new() }
1373 }
1374 }
1375 } else {
1376 quote! {
1378 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1379 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1380 let mut parts = Vec::new();
1381
1382 #(#image_field_parts)*
1384
1385 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1387 format!("Failed to render prompt: {}", e)
1388 });
1389 if !text.is_empty() {
1390 parts.push(#crate_path::prompt::PromptPart::Text(text));
1391 }
1392
1393 parts
1394 }
1395
1396 fn to_prompt(&self) -> String {
1397 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1398 format!("Failed to render prompt: {}", e)
1399 })
1400 }
1401
1402 fn prompt_schema() -> String {
1403 String::new() }
1405 }
1406 }
1407 }
1408 } else {
1409 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1412 &fields.named
1413 } else {
1414 panic!(
1415 "Default prompt generation is only supported for structs with named fields."
1416 );
1417 };
1418
1419 let mut text_field_parts = Vec::new();
1421 let mut image_field_parts = Vec::new();
1422
1423 for f in fields.iter() {
1424 let field_name = f.ident.as_ref().unwrap();
1425 let attrs = parse_field_prompt_attrs(&f.attrs);
1426
1427 if attrs.skip {
1429 continue;
1430 }
1431
1432 if attrs.image {
1433 image_field_parts.push(quote! {
1435 parts.extend(self.#field_name.to_prompt_parts());
1436 });
1437 } else {
1438 let key = if let Some(rename) = attrs.rename {
1444 rename
1445 } else {
1446 let doc_comment = extract_doc_comments(&f.attrs);
1447 if !doc_comment.is_empty() {
1448 doc_comment
1449 } else {
1450 field_name.to_string()
1451 }
1452 };
1453
1454 let value_expr = if let Some(format_with) = attrs.format_with {
1456 let func_path: syn::Path =
1458 syn::parse_str(&format_with).unwrap_or_else(|_| {
1459 panic!("Invalid function path: {}", format_with)
1460 });
1461 quote! { #func_path(&self.#field_name) }
1462 } else {
1463 quote! { self.#field_name.to_prompt() }
1464 };
1465
1466 text_field_parts.push(quote! {
1467 text_parts.push(format!("{}: {}", #key, #value_expr));
1468 });
1469 }
1470 }
1471
1472 quote! {
1474 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1475 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1476 let mut parts = Vec::new();
1477
1478 #(#image_field_parts)*
1480
1481 let mut text_parts = Vec::new();
1483 #(#text_field_parts)*
1484
1485 if !text_parts.is_empty() {
1486 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1487 }
1488
1489 parts
1490 }
1491
1492 fn to_prompt(&self) -> String {
1493 let mut text_parts = Vec::new();
1494 #(#text_field_parts)*
1495 text_parts.join("\n")
1496 }
1497
1498 fn prompt_schema() -> String {
1499 String::new() }
1501 }
1502 }
1503 };
1504
1505 TokenStream::from(expanded)
1506 }
1507 Data::Union(_) => {
1508 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1509 }
1510 }
1511}
1512
1513#[derive(Debug, Clone)]
1515struct TargetInfo {
1516 name: String,
1517 template: Option<String>,
1518 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1519}
1520
/// How a single field participates in one prompt target, parsed from a field-level
/// `#[prompt_for(...)]` attribute. `Default` means "no special handling".
#[derive(Clone, Debug, Default)]
struct FieldTargetConfig {
    /// Leave the field out of the target's prompt.
    skip: bool,
    /// Label used in place of the field name in the `key: value` line.
    rename: Option<String>,
    /// Path of a function called as `f(&self.field)` to format the field's value.
    format_with: Option<String>,
    /// Emit the field through its `to_prompt_parts()` instead of as a text line.
    image: bool,
    /// Set when the config came from an attribute that explicitly named this target.
    include_only: bool,
}
1530
1531fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1533 let mut configs = Vec::new();
1534
1535 for attr in attrs {
1536 if attr.path().is_ident("prompt_for")
1537 && let Ok(meta_list) = attr.meta.require_list()
1538 {
1539 if meta_list.tokens.to_string() == "skip" {
1541 let config = FieldTargetConfig {
1543 skip: true,
1544 ..Default::default()
1545 };
1546 configs.push(("*".to_string(), config));
1547 } else if let Ok(metas) =
1548 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1549 {
1550 let mut target_name = None;
1551 let mut config = FieldTargetConfig::default();
1552
1553 for meta in metas {
1554 match meta {
1555 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1556 if let syn::Expr::Lit(syn::ExprLit {
1557 lit: syn::Lit::Str(lit_str),
1558 ..
1559 }) = nv.value
1560 {
1561 target_name = Some(lit_str.value());
1562 }
1563 }
1564 Meta::Path(path) if path.is_ident("skip") => {
1565 config.skip = true;
1566 }
1567 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1568 if let syn::Expr::Lit(syn::ExprLit {
1569 lit: syn::Lit::Str(lit_str),
1570 ..
1571 }) = nv.value
1572 {
1573 config.rename = Some(lit_str.value());
1574 }
1575 }
1576 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1577 if let syn::Expr::Lit(syn::ExprLit {
1578 lit: syn::Lit::Str(lit_str),
1579 ..
1580 }) = nv.value
1581 {
1582 config.format_with = Some(lit_str.value());
1583 }
1584 }
1585 Meta::Path(path) if path.is_ident("image") => {
1586 config.image = true;
1587 }
1588 _ => {}
1589 }
1590 }
1591
1592 if let Some(name) = target_name {
1593 config.include_only = true;
1594 configs.push((name, config));
1595 }
1596 }
1597 }
1598 }
1599
1600 configs
1601}
1602
1603fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1605 let mut targets = Vec::new();
1606
1607 for attr in attrs {
1608 if attr.path().is_ident("prompt_for")
1609 && let Ok(meta_list) = attr.meta.require_list()
1610 && let Ok(metas) =
1611 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1612 {
1613 let mut target_name = None;
1614 let mut template = None;
1615
1616 for meta in metas {
1617 match meta {
1618 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1619 if let syn::Expr::Lit(syn::ExprLit {
1620 lit: syn::Lit::Str(lit_str),
1621 ..
1622 }) = nv.value
1623 {
1624 target_name = Some(lit_str.value());
1625 }
1626 }
1627 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1628 if let syn::Expr::Lit(syn::ExprLit {
1629 lit: syn::Lit::Str(lit_str),
1630 ..
1631 }) = nv.value
1632 {
1633 template = Some(lit_str.value());
1634 }
1635 }
1636 _ => {}
1637 }
1638 }
1639
1640 if let Some(name) = target_name {
1641 targets.push(TargetInfo {
1642 name,
1643 template,
1644 field_configs: std::collections::HashMap::new(),
1645 });
1646 }
1647 }
1648 }
1649
1650 targets
1651}
1652
1653#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1654pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1655 let input = parse_macro_input!(input as DeriveInput);
1656
1657 let found_crate =
1658 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1659 let crate_path = match found_crate {
1660 FoundCrate::Itself => {
1661 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1663 quote!(::#ident)
1664 }
1665 FoundCrate::Name(name) => {
1666 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1667 quote!(::#ident)
1668 }
1669 };
1670
1671 let data_struct = match &input.data {
1673 Data::Struct(data) => data,
1674 _ => {
1675 return syn::Error::new(
1676 input.ident.span(),
1677 "`#[derive(ToPromptSet)]` is only supported for structs",
1678 )
1679 .to_compile_error()
1680 .into();
1681 }
1682 };
1683
1684 let fields = match &data_struct.fields {
1685 syn::Fields::Named(fields) => &fields.named,
1686 _ => {
1687 return syn::Error::new(
1688 input.ident.span(),
1689 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1690 )
1691 .to_compile_error()
1692 .into();
1693 }
1694 };
1695
1696 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1698
1699 for field in fields.iter() {
1701 let field_name = field.ident.as_ref().unwrap().to_string();
1702 let field_configs = parse_prompt_for_attrs(&field.attrs);
1703
1704 for (target_name, config) in field_configs {
1705 if target_name == "*" {
1706 for target in &mut targets {
1708 target
1709 .field_configs
1710 .entry(field_name.clone())
1711 .or_insert_with(FieldTargetConfig::default)
1712 .skip = config.skip;
1713 }
1714 } else {
1715 let target_exists = targets.iter().any(|t| t.name == target_name);
1717 if !target_exists {
1718 targets.push(TargetInfo {
1720 name: target_name.clone(),
1721 template: None,
1722 field_configs: std::collections::HashMap::new(),
1723 });
1724 }
1725
1726 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1727
1728 target.field_configs.insert(field_name.clone(), config);
1729 }
1730 }
1731 }
1732
1733 let mut match_arms = Vec::new();
1735
1736 for target in &targets {
1737 let target_name = &target.name;
1738
1739 if let Some(template_str) = &target.template {
1740 let mut image_parts = Vec::new();
1742
1743 for field in fields.iter() {
1744 let field_name = field.ident.as_ref().unwrap();
1745 let field_name_str = field_name.to_string();
1746
1747 if let Some(config) = target.field_configs.get(&field_name_str)
1748 && config.image
1749 {
1750 image_parts.push(quote! {
1751 parts.extend(self.#field_name.to_prompt_parts());
1752 });
1753 }
1754 }
1755
1756 match_arms.push(quote! {
1757 #target_name => {
1758 let mut parts = Vec::new();
1759
1760 #(#image_parts)*
1761
1762 let text = #crate_path::prompt::render_prompt(#template_str, self)
1763 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1764 target: #target_name.to_string(),
1765 source: e,
1766 })?;
1767
1768 if !text.is_empty() {
1769 parts.push(#crate_path::prompt::PromptPart::Text(text));
1770 }
1771
1772 Ok(parts)
1773 }
1774 });
1775 } else {
1776 let mut text_field_parts = Vec::new();
1778 let mut image_field_parts = Vec::new();
1779
1780 for field in fields.iter() {
1781 let field_name = field.ident.as_ref().unwrap();
1782 let field_name_str = field_name.to_string();
1783
1784 let config = target.field_configs.get(&field_name_str);
1786
1787 if let Some(cfg) = config
1789 && cfg.skip
1790 {
1791 continue;
1792 }
1793
1794 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1798 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1799 .iter()
1800 .any(|(name, _)| name != "*");
1801
1802 if has_any_target_specific_config && !is_explicitly_for_this_target {
1803 continue;
1804 }
1805
1806 if let Some(cfg) = config {
1807 if cfg.image {
1808 image_field_parts.push(quote! {
1809 parts.extend(self.#field_name.to_prompt_parts());
1810 });
1811 } else {
1812 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1813
1814 let value_expr = if let Some(format_with) = &cfg.format_with {
1815 match syn::parse_str::<syn::Path>(format_with) {
1817 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1818 Err(_) => {
1819 let error_msg = format!(
1821 "Invalid function path in format_with: '{}'",
1822 format_with
1823 );
1824 quote! {
1825 compile_error!(#error_msg);
1826 String::new()
1827 }
1828 }
1829 }
1830 } else {
1831 quote! { self.#field_name.to_prompt() }
1832 };
1833
1834 text_field_parts.push(quote! {
1835 text_parts.push(format!("{}: {}", #key, #value_expr));
1836 });
1837 }
1838 } else {
1839 text_field_parts.push(quote! {
1841 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1842 });
1843 }
1844 }
1845
1846 match_arms.push(quote! {
1847 #target_name => {
1848 let mut parts = Vec::new();
1849
1850 #(#image_field_parts)*
1851
1852 let mut text_parts = Vec::new();
1853 #(#text_field_parts)*
1854
1855 if !text_parts.is_empty() {
1856 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1857 }
1858
1859 Ok(parts)
1860 }
1861 });
1862 }
1863 }
1864
1865 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1867
1868 match_arms.push(quote! {
1870 _ => {
1871 let available = vec![#(#target_names.to_string()),*];
1872 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1873 target: target.to_string(),
1874 available,
1875 })
1876 }
1877 });
1878
1879 let struct_name = &input.ident;
1880 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1881
1882 let expanded = quote! {
1883 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1884 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1885 match target {
1886 #(#match_arms)*
1887 }
1888 }
1889 }
1890 };
1891
1892 TokenStream::from(expanded)
1893}
1894
1895struct TypeList {
1897 types: Punctuated<syn::Type, Token![,]>,
1898}
1899
1900impl Parse for TypeList {
1901 fn parse(input: ParseStream) -> syn::Result<Self> {
1902 Ok(TypeList {
1903 types: Punctuated::parse_terminated(input)?,
1904 })
1905 }
1906}
1907
1908#[proc_macro]
1932pub fn examples_section(input: TokenStream) -> TokenStream {
1933 let input = parse_macro_input!(input as TypeList);
1934
1935 let found_crate =
1936 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1937 let _crate_path = match found_crate {
1938 FoundCrate::Itself => quote!(crate),
1939 FoundCrate::Name(name) => {
1940 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1941 quote!(::#ident)
1942 }
1943 };
1944
1945 let mut type_sections = Vec::new();
1947
1948 for ty in input.types.iter() {
1949 let type_name_str = quote!(#ty).to_string();
1951
1952 type_sections.push(quote! {
1954 {
1955 let type_name = #type_name_str;
1956 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1957 format!("---\n#### `{}`\n{}", type_name, json_example)
1958 }
1959 });
1960 }
1961
1962 let expanded = quote! {
1964 {
1965 let mut sections = Vec::new();
1966 sections.push("---".to_string());
1967 sections.push("### Examples".to_string());
1968 sections.push("".to_string());
1969 sections.push("Here are examples of the data structures you should use.".to_string());
1970 sections.push("".to_string());
1971
1972 #(sections.push(#type_sections);)*
1973
1974 sections.push("---".to_string());
1975
1976 sections.join("\n")
1977 }
1978 };
1979
1980 TokenStream::from(expanded)
1981}
1982
1983fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1985 for attr in attrs {
1986 if attr.path().is_ident("prompt_for")
1987 && let Ok(meta_list) = attr.meta.require_list()
1988 && let Ok(metas) =
1989 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1990 {
1991 let mut target_type = None;
1992 let mut template = None;
1993
1994 for meta in metas {
1995 match meta {
1996 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1997 if let syn::Expr::Lit(syn::ExprLit {
1998 lit: syn::Lit::Str(lit_str),
1999 ..
2000 }) = nv.value
2001 {
2002 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2004 }
2005 }
2006 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2007 if let syn::Expr::Lit(syn::ExprLit {
2008 lit: syn::Lit::Str(lit_str),
2009 ..
2010 }) = nv.value
2011 {
2012 template = Some(lit_str.value());
2013 }
2014 }
2015 _ => {}
2016 }
2017 }
2018
2019 if let (Some(target), Some(tmpl)) = (target_type, template) {
2020 return (target, tmpl);
2021 }
2022 }
2023 }
2024
2025 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2026}
2027
2028#[proc_macro_attribute]
2062pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2063 let input = parse_macro_input!(item as DeriveInput);
2064
2065 let found_crate =
2066 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2067 let crate_path = match found_crate {
2068 FoundCrate::Itself => {
2069 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2071 quote!(::#ident)
2072 }
2073 FoundCrate::Name(name) => {
2074 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2075 quote!(::#ident)
2076 }
2077 };
2078
2079 let enum_data = match &input.data {
2081 Data::Enum(data) => data,
2082 _ => {
2083 return syn::Error::new(
2084 input.ident.span(),
2085 "`#[define_intent]` can only be applied to enums",
2086 )
2087 .to_compile_error()
2088 .into();
2089 }
2090 };
2091
2092 let mut prompt_template = None;
2094 let mut extractor_tag = None;
2095 let mut mode = None;
2096
2097 for attr in &input.attrs {
2098 if attr.path().is_ident("intent")
2099 && let Ok(metas) =
2100 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2101 {
2102 for meta in metas {
2103 match meta {
2104 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2105 if let syn::Expr::Lit(syn::ExprLit {
2106 lit: syn::Lit::Str(lit_str),
2107 ..
2108 }) = nv.value
2109 {
2110 prompt_template = Some(lit_str.value());
2111 }
2112 }
2113 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2114 if let syn::Expr::Lit(syn::ExprLit {
2115 lit: syn::Lit::Str(lit_str),
2116 ..
2117 }) = nv.value
2118 {
2119 extractor_tag = Some(lit_str.value());
2120 }
2121 }
2122 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2123 if let syn::Expr::Lit(syn::ExprLit {
2124 lit: syn::Lit::Str(lit_str),
2125 ..
2126 }) = nv.value
2127 {
2128 mode = Some(lit_str.value());
2129 }
2130 }
2131 _ => {}
2132 }
2133 }
2134 }
2135 }
2136
2137 let mode = mode.unwrap_or_else(|| "single".to_string());
2139
2140 if mode != "single" && mode != "multi_tag" {
2142 return syn::Error::new(
2143 input.ident.span(),
2144 "`mode` must be either \"single\" or \"multi_tag\"",
2145 )
2146 .to_compile_error()
2147 .into();
2148 }
2149
2150 let prompt_template = match prompt_template {
2152 Some(p) => p,
2153 None => {
2154 return syn::Error::new(
2155 input.ident.span(),
2156 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2157 )
2158 .to_compile_error()
2159 .into();
2160 }
2161 };
2162
2163 if mode == "multi_tag" {
2165 let enum_name = &input.ident;
2166 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2167 return generate_multi_tag_output(
2168 &input,
2169 enum_name,
2170 enum_data,
2171 prompt_template,
2172 actions_doc,
2173 );
2174 }
2175
2176 let extractor_tag = match extractor_tag {
2178 Some(t) => t,
2179 None => {
2180 return syn::Error::new(
2181 input.ident.span(),
2182 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2183 )
2184 .to_compile_error()
2185 .into();
2186 }
2187 };
2188
2189 let enum_name = &input.ident;
2191 let enum_docs = extract_doc_comments(&input.attrs);
2192
2193 let mut intents_doc_lines = Vec::new();
2194
2195 if !enum_docs.is_empty() {
2197 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2198 } else {
2199 intents_doc_lines.push(format!("{}:", enum_name));
2200 }
2201 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2203
2204 for variant in &enum_data.variants {
2206 let variant_name = &variant.ident;
2207 let variant_docs = extract_doc_comments(&variant.attrs);
2208
2209 if !variant_docs.is_empty() {
2210 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2211 } else {
2212 intents_doc_lines.push(format!("- {}", variant_name));
2213 }
2214 }
2215
2216 let intents_doc_str = intents_doc_lines.join("\n");
2217
2218 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2220 let user_variables: Vec<String> = placeholders
2221 .iter()
2222 .filter_map(|(name, _)| {
2223 if name != "intents_doc" {
2224 Some(name.clone())
2225 } else {
2226 None
2227 }
2228 })
2229 .collect();
2230
2231 let enum_name_str = enum_name.to_string();
2233 let snake_case_name = to_snake_case(&enum_name_str);
2234 let function_name = syn::Ident::new(
2235 &format!("build_{}_prompt", snake_case_name),
2236 proc_macro2::Span::call_site(),
2237 );
2238
2239 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2241 .iter()
2242 .map(|var| {
2243 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2244 quote! { #ident: &str }
2245 })
2246 .collect();
2247
2248 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2250 .iter()
2251 .map(|var| {
2252 let var_str = var.clone();
2253 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2254 quote! {
2255 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2256 }
2257 })
2258 .collect();
2259
2260 let converted_template = prompt_template.clone();
2262
2263 let extractor_name = syn::Ident::new(
2265 &format!("{}Extractor", enum_name),
2266 proc_macro2::Span::call_site(),
2267 );
2268
2269 let filtered_attrs: Vec<_> = input
2271 .attrs
2272 .iter()
2273 .filter(|attr| !attr.path().is_ident("intent"))
2274 .collect();
2275
2276 let vis = &input.vis;
2278 let generics = &input.generics;
2279 let variants = &enum_data.variants;
2280 let enum_output = quote! {
2281 #(#filtered_attrs)*
2282 #vis enum #enum_name #generics {
2283 #variants
2284 }
2285 };
2286
2287 let expanded = quote! {
2289 #enum_output
2291
2292 pub fn #function_name(#(#function_params),*) -> String {
2294 let mut env = minijinja::Environment::new();
2295 env.add_template("prompt", #converted_template)
2296 .expect("Failed to parse intent prompt template");
2297
2298 let tmpl = env.get_template("prompt").unwrap();
2299
2300 let mut __template_context = std::collections::HashMap::new();
2301
2302 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2304
2305 #(#context_insertions)*
2307
2308 tmpl.render(&__template_context)
2309 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2310 }
2311
2312 pub struct #extractor_name;
2314
2315 impl #extractor_name {
2316 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2317 }
2318
2319 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2320 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2321 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2323 }
2324 }
2325 };
2326
2327 TokenStream::from(expanded)
2328}
2329
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when it follows a
/// non-uppercase one, so runs of capitals stay together (`"HTTPServer"` -> `"httpserver"`).
/// Each uppercase character is replaced by the first character of its lowercase form.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut last_was_upper = false;

    for (idx, c) in s.chars().enumerate() {
        let is_upper = c.is_uppercase();
        if is_upper {
            let starts_new_word = idx != 0 && !last_was_upper;
            if starts_new_word {
                snake.push('_');
            }
            snake.push(c.to_lowercase().next().unwrap());
        } else {
            snake.push(c);
        }
        last_was_upper = is_upper;
    }

    snake
}
2350
/// Variant-level settings read from `#[action(...)]` in `"multi_tag"` intents.
#[derive(Default, Debug)]
struct ActionAttrs {
    /// XML tag name from `#[action(tag = "...")]`; `None` means the variant is not an action.
    tag: Option<String>,
}
2356
2357fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2358 let mut result = ActionAttrs::default();
2359
2360 for attr in attrs {
2361 if attr.path().is_ident("action")
2362 && let Ok(meta_list) = attr.meta.require_list()
2363 && let Ok(metas) =
2364 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2365 {
2366 for meta in metas {
2367 if let Meta::NameValue(nv) = meta
2368 && nv.path.is_ident("tag")
2369 && let syn::Expr::Lit(syn::ExprLit {
2370 lit: syn::Lit::Str(lit_str),
2371 ..
2372 }) = nv.value
2373 {
2374 result.tag = Some(lit_str.value());
2375 }
2376 }
2377 }
2378 }
2379
2380 result
2381}
2382
/// Field-level markers read from `#[action(...)]` on struct-like action variants.
#[derive(Default, Debug)]
struct FieldActionAttrs {
    /// `#[action(attribute)]`: the field maps to an XML attribute of the tag.
    is_attribute: bool,
    /// `#[action(inner_text)]`: the field maps to the tag's text content.
    is_inner_text: bool,
}
2389
2390fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2391 let mut result = FieldActionAttrs::default();
2392
2393 for attr in attrs {
2394 if attr.path().is_ident("action")
2395 && let Ok(meta_list) = attr.meta.require_list()
2396 {
2397 let tokens_str = meta_list.tokens.to_string();
2398 if tokens_str == "attribute" {
2399 result.is_attribute = true;
2400 } else if tokens_str == "inner_text" {
2401 result.is_inner_text = true;
2402 }
2403 }
2404 }
2405
2406 result
2407}
2408
2409fn generate_multi_tag_actions_doc(
2411 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2412) -> String {
2413 let mut doc_lines = Vec::new();
2414
2415 for variant in variants {
2416 let action_attrs = parse_action_attrs(&variant.attrs);
2417
2418 if let Some(tag) = action_attrs.tag {
2419 let variant_docs = extract_doc_comments(&variant.attrs);
2420
2421 match &variant.fields {
2422 syn::Fields::Unit => {
2423 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2425 }
2426 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2427 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2429 }
2430 syn::Fields::Named(fields) => {
2431 let mut attrs_str = Vec::new();
2433 let mut has_inner_text = false;
2434
2435 for field in &fields.named {
2436 let field_name = field.ident.as_ref().unwrap();
2437 let field_attrs = parse_field_action_attrs(&field.attrs);
2438
2439 if field_attrs.is_attribute {
2440 attrs_str.push(format!("{}=\"...\"", field_name));
2441 } else if field_attrs.is_inner_text {
2442 has_inner_text = true;
2443 }
2444 }
2445
2446 let attrs_part = if !attrs_str.is_empty() {
2447 format!(" {}", attrs_str.join(" "))
2448 } else {
2449 String::new()
2450 };
2451
2452 if has_inner_text {
2453 doc_lines.push(format!(
2454 "- `<{}{}>...</{}>`: {}",
2455 tag, attrs_part, tag, variant_docs
2456 ));
2457 } else if !attrs_str.is_empty() {
2458 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2459 } else {
2460 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2461 }
2462
2463 for field in &fields.named {
2465 let field_name = field.ident.as_ref().unwrap();
2466 let field_attrs = parse_field_action_attrs(&field.attrs);
2467 let field_docs = extract_doc_comments(&field.attrs);
2468
2469 if field_attrs.is_attribute {
2470 doc_lines
2471 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2472 } else if field_attrs.is_inner_text {
2473 doc_lines
2474 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2475 }
2476 }
2477 }
2478 _ => {
2479 }
2481 }
2482 }
2483 }
2484
2485 doc_lines.join("\n")
2486}
2487
2488fn generate_tags_regex(
2490 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2491) -> String {
2492 let mut tag_names = Vec::new();
2493
2494 for variant in variants {
2495 let action_attrs = parse_action_attrs(&variant.attrs);
2496 if let Some(tag) = action_attrs.tag {
2497 tag_names.push(tag);
2498 }
2499 }
2500
2501 if tag_names.is_empty() {
2502 return String::new();
2503 }
2504
2505 let tags_pattern = tag_names.join("|");
2506 format!(
2509 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2510 tags_pattern, tags_pattern, tags_pattern
2511 )
2512}
2513
2514fn generate_multi_tag_output(
2516 input: &DeriveInput,
2517 enum_name: &syn::Ident,
2518 enum_data: &syn::DataEnum,
2519 prompt_template: String,
2520 actions_doc: String,
2521) -> TokenStream {
2522 let found_crate =
2523 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2524 let crate_path = match found_crate {
2525 FoundCrate::Itself => {
2526 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2528 quote!(::#ident)
2529 }
2530 FoundCrate::Name(name) => {
2531 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2532 quote!(::#ident)
2533 }
2534 };
2535
2536 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2538 let user_variables: Vec<String> = placeholders
2539 .iter()
2540 .filter_map(|(name, _)| {
2541 if name != "actions_doc" {
2542 Some(name.clone())
2543 } else {
2544 None
2545 }
2546 })
2547 .collect();
2548
2549 let enum_name_str = enum_name.to_string();
2551 let snake_case_name = to_snake_case(&enum_name_str);
2552 let function_name = syn::Ident::new(
2553 &format!("build_{}_prompt", snake_case_name),
2554 proc_macro2::Span::call_site(),
2555 );
2556
2557 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2559 .iter()
2560 .map(|var| {
2561 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2562 quote! { #ident: &str }
2563 })
2564 .collect();
2565
2566 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2568 .iter()
2569 .map(|var| {
2570 let var_str = var.clone();
2571 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2572 quote! {
2573 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2574 }
2575 })
2576 .collect();
2577
2578 let extractor_name = syn::Ident::new(
2580 &format!("{}Extractor", enum_name),
2581 proc_macro2::Span::call_site(),
2582 );
2583
2584 let filtered_attrs: Vec<_> = input
2586 .attrs
2587 .iter()
2588 .filter(|attr| !attr.path().is_ident("intent"))
2589 .collect();
2590
2591 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2593 .variants
2594 .iter()
2595 .map(|variant| {
2596 let variant_name = &variant.ident;
2597 let variant_attrs: Vec<_> = variant
2598 .attrs
2599 .iter()
2600 .filter(|attr| !attr.path().is_ident("action"))
2601 .collect();
2602 let fields = &variant.fields;
2603
2604 let filtered_fields = match fields {
2606 syn::Fields::Named(named_fields) => {
2607 let filtered: Vec<_> = named_fields
2608 .named
2609 .iter()
2610 .map(|field| {
2611 let field_name = &field.ident;
2612 let field_type = &field.ty;
2613 let field_vis = &field.vis;
2614 let filtered_attrs: Vec<_> = field
2615 .attrs
2616 .iter()
2617 .filter(|attr| !attr.path().is_ident("action"))
2618 .collect();
2619 quote! {
2620 #(#filtered_attrs)*
2621 #field_vis #field_name: #field_type
2622 }
2623 })
2624 .collect();
2625 quote! { { #(#filtered,)* } }
2626 }
2627 syn::Fields::Unnamed(unnamed_fields) => {
2628 let types: Vec<_> = unnamed_fields
2629 .unnamed
2630 .iter()
2631 .map(|field| {
2632 let field_type = &field.ty;
2633 quote! { #field_type }
2634 })
2635 .collect();
2636 quote! { (#(#types),*) }
2637 }
2638 syn::Fields::Unit => quote! {},
2639 };
2640
2641 quote! {
2642 #(#variant_attrs)*
2643 #variant_name #filtered_fields
2644 }
2645 })
2646 .collect();
2647
2648 let vis = &input.vis;
2649 let generics = &input.generics;
2650
2651 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2653
2654 let tags_regex = generate_tags_regex(&enum_data.variants);
2656
2657 let expanded = quote! {
2658 #(#filtered_attrs)*
2660 #vis enum #enum_name #generics {
2661 #(#filtered_variants),*
2662 }
2663
2664 pub fn #function_name(#(#function_params),*) -> String {
2666 let mut env = minijinja::Environment::new();
2667 env.add_template("prompt", #prompt_template)
2668 .expect("Failed to parse intent prompt template");
2669
2670 let tmpl = env.get_template("prompt").unwrap();
2671
2672 let mut __template_context = std::collections::HashMap::new();
2673
2674 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2676
2677 #(#context_insertions)*
2679
2680 tmpl.render(&__template_context)
2681 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2682 }
2683
2684 pub struct #extractor_name;
2686
2687 impl #extractor_name {
2688 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2689 use ::quick_xml::events::Event;
2690 use ::quick_xml::Reader;
2691
2692 let mut actions = Vec::new();
2693 let mut reader = Reader::from_str(text);
2694 reader.config_mut().trim_text(true);
2695
2696 let mut buf = Vec::new();
2697
2698 loop {
2699 match reader.read_event_into(&mut buf) {
2700 Ok(Event::Start(e)) => {
2701 let owned_e = e.into_owned();
2702 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2703 let is_empty = false;
2704
2705 #parsing_arms
2706 }
2707 Ok(Event::Empty(e)) => {
2708 let owned_e = e.into_owned();
2709 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2710 let is_empty = true;
2711
2712 #parsing_arms
2713 }
2714 Ok(Event::Eof) => break,
2715 Err(_) => {
2716 break;
2718 }
2719 _ => {}
2720 }
2721 buf.clear();
2722 }
2723
2724 actions.into_iter().next()
2725 }
2726
2727 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2728 use ::quick_xml::events::Event;
2729 use ::quick_xml::Reader;
2730
2731 let mut actions = Vec::new();
2732 let mut reader = Reader::from_str(text);
2733 reader.config_mut().trim_text(true);
2734
2735 let mut buf = Vec::new();
2736
2737 loop {
2738 match reader.read_event_into(&mut buf) {
2739 Ok(Event::Start(e)) => {
2740 let owned_e = e.into_owned();
2741 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2742 let is_empty = false;
2743
2744 #parsing_arms
2745 }
2746 Ok(Event::Empty(e)) => {
2747 let owned_e = e.into_owned();
2748 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2749 let is_empty = true;
2750
2751 #parsing_arms
2752 }
2753 Ok(Event::Eof) => break,
2754 Err(_) => {
2755 break;
2757 }
2758 _ => {}
2759 }
2760 buf.clear();
2761 }
2762
2763 Ok(actions)
2764 }
2765
2766 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2767 where
2768 F: FnMut(#enum_name) -> String,
2769 {
2770 use ::regex::Regex;
2771
2772 let regex_pattern = #tags_regex;
2773 if regex_pattern.is_empty() {
2774 return text.to_string();
2775 }
2776
2777 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2778 panic!("Failed to compile regex for action tags: {}", e);
2779 });
2780
2781 re.replace_all(text, |caps: &::regex::Captures| {
2782 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2783
2784 if let Some(action) = self.parse_single_action(matched) {
2786 transformer(action)
2787 } else {
2788 matched.to_string()
2790 }
2791 }).to_string()
2792 }
2793
2794 pub fn strip_actions(&self, text: &str) -> String {
2795 self.transform_actions(text, |_| String::new())
2796 }
2797 }
2798 };
2799
2800 TokenStream::from(expanded)
2801}
2802
2803fn generate_parsing_arms(
2805 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2806 enum_name: &syn::Ident,
2807) -> proc_macro2::TokenStream {
2808 let mut arms = Vec::new();
2809
2810 for variant in variants {
2811 let variant_name = &variant.ident;
2812 let action_attrs = parse_action_attrs(&variant.attrs);
2813
2814 if let Some(tag) = action_attrs.tag {
2815 match &variant.fields {
2816 syn::Fields::Unit => {
2817 arms.push(quote! {
2819 if &tag_name == #tag {
2820 actions.push(#enum_name::#variant_name);
2821 }
2822 });
2823 }
2824 syn::Fields::Unnamed(_fields) => {
2825 arms.push(quote! {
2827 if &tag_name == #tag && !is_empty {
2828 match reader.read_text(owned_e.name()) {
2830 Ok(text) => {
2831 actions.push(#enum_name::#variant_name(text.to_string()));
2832 }
2833 Err(_) => {
2834 actions.push(#enum_name::#variant_name(String::new()));
2836 }
2837 }
2838 }
2839 });
2840 }
2841 syn::Fields::Named(fields) => {
2842 let mut field_names = Vec::new();
2844 let mut has_inner_text_field = None;
2845
2846 for field in &fields.named {
2847 let field_name = field.ident.as_ref().unwrap();
2848 let field_attrs = parse_field_action_attrs(&field.attrs);
2849
2850 if field_attrs.is_attribute {
2851 field_names.push(field_name.clone());
2852 } else if field_attrs.is_inner_text {
2853 has_inner_text_field = Some(field_name.clone());
2854 }
2855 }
2856
2857 if let Some(inner_text_field) = has_inner_text_field {
2858 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2861 quote! {
2862 let mut #field_name = String::new();
2863 for attr in owned_e.attributes() {
2864 if let Ok(attr) = attr {
2865 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2866 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2867 break;
2868 }
2869 }
2870 }
2871 }
2872 }).collect();
2873
2874 arms.push(quote! {
2875 if &tag_name == #tag {
2876 #(#attr_extractions)*
2877
2878 if is_empty {
2880 let #inner_text_field = String::new();
2881 actions.push(#enum_name::#variant_name {
2882 #(#field_names,)*
2883 #inner_text_field,
2884 });
2885 } else {
2886 match reader.read_text(owned_e.name()) {
2888 Ok(text) => {
2889 let #inner_text_field = text.to_string();
2890 actions.push(#enum_name::#variant_name {
2891 #(#field_names,)*
2892 #inner_text_field,
2893 });
2894 }
2895 Err(_) => {
2896 let #inner_text_field = String::new();
2898 actions.push(#enum_name::#variant_name {
2899 #(#field_names,)*
2900 #inner_text_field,
2901 });
2902 }
2903 }
2904 }
2905 }
2906 });
2907 } else {
2908 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2910 quote! {
2911 let mut #field_name = String::new();
2912 for attr in owned_e.attributes() {
2913 if let Ok(attr) = attr {
2914 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2915 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2916 break;
2917 }
2918 }
2919 }
2920 }
2921 }).collect();
2922
2923 arms.push(quote! {
2924 if &tag_name == #tag {
2925 #(#attr_extractions)*
2926 actions.push(#enum_name::#variant_name {
2927 #(#field_names),*
2928 });
2929 }
2930 });
2931 }
2932 }
2933 }
2934 }
2935 }
2936
2937 quote! {
2938 #(#arms)*
2939 }
2940}
2941
2942#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2944pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2945 let input = parse_macro_input!(input as DeriveInput);
2946
2947 let found_crate =
2948 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2949 let crate_path = match found_crate {
2950 FoundCrate::Itself => {
2951 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2953 quote!(::#ident)
2954 }
2955 FoundCrate::Name(name) => {
2956 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2957 quote!(::#ident)
2958 }
2959 };
2960
2961 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2963
2964 let struct_name = &input.ident;
2965 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2966
2967 let placeholders = parse_template_placeholders_with_mode(&template);
2969
2970 let mut converted_template = template.clone();
2972 let mut context_fields = Vec::new();
2973
2974 let fields = match &input.data {
2976 Data::Struct(data_struct) => match &data_struct.fields {
2977 syn::Fields::Named(fields) => &fields.named,
2978 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2979 },
2980 _ => panic!("ToPromptFor is only supported for structs"),
2981 };
2982
2983 let has_mode_support = input.attrs.iter().any(|attr| {
2985 if attr.path().is_ident("prompt")
2986 && let Ok(metas) =
2987 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2988 {
2989 for meta in metas {
2990 if let Meta::NameValue(nv) = meta
2991 && nv.path.is_ident("mode")
2992 {
2993 return true;
2994 }
2995 }
2996 }
2997 false
2998 });
2999
3000 for (placeholder_name, mode_opt) in &placeholders {
3002 if placeholder_name == "self" {
3003 if let Some(specific_mode) = mode_opt {
3004 let unique_key = format!("self__{}", specific_mode);
3006
3007 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3009 let replacement = format!("{{{{ {} }}}}", unique_key);
3010 converted_template = converted_template.replace(&pattern, &replacement);
3011
3012 context_fields.push(quote! {
3014 context.insert(
3015 #unique_key.to_string(),
3016 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3017 );
3018 });
3019 } else {
3020 if has_mode_support {
3023 context_fields.push(quote! {
3025 context.insert(
3026 "self".to_string(),
3027 minijinja::Value::from(self.to_prompt_with_mode(mode))
3028 );
3029 });
3030 } else {
3031 context_fields.push(quote! {
3033 context.insert(
3034 "self".to_string(),
3035 minijinja::Value::from(self.to_prompt())
3036 );
3037 });
3038 }
3039 }
3040 } else {
3041 let field_exists = fields.iter().any(|f| {
3044 f.ident
3045 .as_ref()
3046 .is_some_and(|ident| ident == placeholder_name)
3047 });
3048
3049 if field_exists {
3050 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3051
3052 context_fields.push(quote! {
3056 context.insert(
3057 #placeholder_name.to_string(),
3058 minijinja::Value::from_serialize(&self.#field_ident)
3059 );
3060 });
3061 }
3062 }
3064 }
3065
3066 let expanded = quote! {
3067 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3068 where
3069 #target_type: serde::Serialize,
3070 {
3071 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3072 let mut env = minijinja::Environment::new();
3074 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3075 panic!("Failed to parse template: {}", e)
3076 });
3077
3078 let tmpl = env.get_template("prompt").unwrap();
3079
3080 let mut context = std::collections::HashMap::new();
3082 context.insert(
3084 "self".to_string(),
3085 minijinja::Value::from_serialize(self)
3086 );
3087 context.insert(
3089 "target".to_string(),
3090 minijinja::Value::from_serialize(target)
3091 );
3092 #(#context_fields)*
3093
3094 tmpl.render(context).unwrap_or_else(|e| {
3096 format!("Failed to render prompt: {}", e)
3097 })
3098 }
3099 }
3100 };
3101
3102 TokenStream::from(expanded)
3103}
3104
/// Options parsed from `#[agent(...)]`, shared by the `Agent` derive and the `#[agent]`
/// attribute macro. Each field is `None` when its key is absent from the attribute.
struct AgentAttrs {
    /// `expertise = "..."`: expertise text returned by the generated `expertise()`.
    expertise: Option<String>,
    /// `output = "Type"`: agent output type, parsed from the string into a `syn::Type`.
    output: Option<syn::Type>,
    /// `backend = "..."`: backend name; the code generators recognise "claude" and "gemini".
    backend: Option<String>,
    /// `model = "..."`: model name forwarded to the backend's `with_model_str`.
    model: Option<String>,
    /// `inner = "..."`: name of the wrapped-agent generic parameter (`#[agent]` only).
    inner: Option<String>,
    /// `default_inner = "Type"`: default type of the wrapped agent (`#[agent]` only).
    default_inner: Option<String>,
    /// `max_retries = N`: retry limit used by the `Agent` derive.
    max_retries: Option<u32>,
    /// `profile = "..."`: execution profile name.
    profile: Option<String>,
}
3120
3121impl Parse for AgentAttrs {
3122 fn parse(input: ParseStream) -> syn::Result<Self> {
3123 let mut expertise = None;
3124 let mut output = None;
3125 let mut backend = None;
3126 let mut model = None;
3127 let mut inner = None;
3128 let mut default_inner = None;
3129 let mut max_retries = None;
3130 let mut profile = None;
3131
3132 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3133
3134 for meta in pairs {
3135 match meta {
3136 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3137 if let syn::Expr::Lit(syn::ExprLit {
3138 lit: syn::Lit::Str(lit_str),
3139 ..
3140 }) = &nv.value
3141 {
3142 expertise = Some(lit_str.value());
3143 }
3144 }
3145 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3146 if let syn::Expr::Lit(syn::ExprLit {
3147 lit: syn::Lit::Str(lit_str),
3148 ..
3149 }) = &nv.value
3150 {
3151 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3152 output = Some(ty);
3153 }
3154 }
3155 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3156 if let syn::Expr::Lit(syn::ExprLit {
3157 lit: syn::Lit::Str(lit_str),
3158 ..
3159 }) = &nv.value
3160 {
3161 backend = Some(lit_str.value());
3162 }
3163 }
3164 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3165 if let syn::Expr::Lit(syn::ExprLit {
3166 lit: syn::Lit::Str(lit_str),
3167 ..
3168 }) = &nv.value
3169 {
3170 model = Some(lit_str.value());
3171 }
3172 }
3173 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3174 if let syn::Expr::Lit(syn::ExprLit {
3175 lit: syn::Lit::Str(lit_str),
3176 ..
3177 }) = &nv.value
3178 {
3179 inner = Some(lit_str.value());
3180 }
3181 }
3182 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3183 if let syn::Expr::Lit(syn::ExprLit {
3184 lit: syn::Lit::Str(lit_str),
3185 ..
3186 }) = &nv.value
3187 {
3188 default_inner = Some(lit_str.value());
3189 }
3190 }
3191 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3192 if let syn::Expr::Lit(syn::ExprLit {
3193 lit: syn::Lit::Int(lit_int),
3194 ..
3195 }) = &nv.value
3196 {
3197 max_retries = Some(lit_int.base10_parse()?);
3198 }
3199 }
3200 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3201 if let syn::Expr::Lit(syn::ExprLit {
3202 lit: syn::Lit::Str(lit_str),
3203 ..
3204 }) = &nv.value
3205 {
3206 profile = Some(lit_str.value());
3207 }
3208 }
3209 _ => {}
3210 }
3211 }
3212
3213 Ok(AgentAttrs {
3214 expertise,
3215 output,
3216 backend,
3217 model,
3218 inner,
3219 default_inner,
3220 max_retries,
3221 profile,
3222 })
3223 }
3224}
3225
3226fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3228 for attr in attrs {
3229 if attr.path().is_ident("agent") {
3230 return attr.parse_args::<AgentAttrs>();
3231 }
3232 }
3233
3234 Ok(AgentAttrs {
3235 expertise: None,
3236 output: None,
3237 backend: None,
3238 model: None,
3239 inner: None,
3240 default_inner: None,
3241 max_retries: None,
3242 profile: None,
3243 })
3244}
3245
3246fn generate_backend_constructors(
3248 struct_name: &syn::Ident,
3249 backend: &str,
3250 _model: Option<&str>,
3251 _profile: Option<&str>,
3252 crate_path: &proc_macro2::TokenStream,
3253) -> proc_macro2::TokenStream {
3254 match backend {
3255 "claude" => {
3256 quote! {
3257 impl #struct_name {
3258 pub fn with_claude() -> Self {
3260 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3261 }
3262
3263 pub fn with_claude_model(model: &str) -> Self {
3265 Self::new(
3266 #crate_path::agent::impls::ClaudeCodeAgent::new()
3267 .with_model_str(model)
3268 )
3269 }
3270 }
3271 }
3272 }
3273 "gemini" => {
3274 quote! {
3275 impl #struct_name {
3276 pub fn with_gemini() -> Self {
3278 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3279 }
3280
3281 pub fn with_gemini_model(model: &str) -> Self {
3283 Self::new(
3284 #crate_path::agent::impls::GeminiAgent::new()
3285 .with_model_str(model)
3286 )
3287 }
3288 }
3289 }
3290 }
3291 _ => quote! {},
3292 }
3293}
3294
3295fn generate_default_impl(
3297 struct_name: &syn::Ident,
3298 backend: &str,
3299 model: Option<&str>,
3300 profile: Option<&str>,
3301 crate_path: &proc_macro2::TokenStream,
3302) -> proc_macro2::TokenStream {
3303 let profile_expr = if let Some(profile_str) = profile {
3305 match profile_str.to_lowercase().as_str() {
3306 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3307 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3308 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3309 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3311 } else {
3312 quote! { #crate_path::agent::ExecutionProfile::default() }
3313 };
3314
3315 let agent_init = match backend {
3316 "gemini" => {
3317 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3318
3319 if let Some(model_str) = model {
3320 builder = quote! { #builder.with_model_str(#model_str) };
3321 }
3322
3323 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3324 builder
3325 }
3326 _ => {
3327 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3329
3330 if let Some(model_str) = model {
3331 builder = quote! { #builder.with_model_str(#model_str) };
3332 }
3333
3334 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3335 builder
3336 }
3337 };
3338
3339 quote! {
3340 impl Default for #struct_name {
3341 fn default() -> Self {
3342 Self::new(#agent_init)
3343 }
3344 }
3345 }
3346}
3347
3348#[proc_macro_derive(Agent, attributes(agent))]
3357pub fn derive_agent(input: TokenStream) -> TokenStream {
3358 let input = parse_macro_input!(input as DeriveInput);
3359 let struct_name = &input.ident;
3360
3361 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3363 Ok(attrs) => attrs,
3364 Err(e) => return e.to_compile_error().into(),
3365 };
3366
3367 let expertise = agent_attrs
3368 .expertise
3369 .unwrap_or_else(|| String::from("general AI assistant"));
3370 let output_type = agent_attrs
3371 .output
3372 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3373 let backend = agent_attrs
3374 .backend
3375 .unwrap_or_else(|| String::from("claude"));
3376 let model = agent_attrs.model;
3377 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3382 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3383 let crate_path = match found_crate {
3384 FoundCrate::Itself => {
3385 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3387 quote!(::#ident)
3388 }
3389 FoundCrate::Name(name) => {
3390 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3391 quote!(::#ident)
3392 }
3393 };
3394
3395 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3396
3397 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3399 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3400
3401 let enhanced_expertise = if is_string_output {
3403 quote! { #expertise }
3405 } else {
3406 let type_name = quote!(#output_type).to_string();
3408 quote! {
3409 {
3410 use std::sync::OnceLock;
3411 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3412
3413 EXPERTISE_CACHE.get_or_init(|| {
3414 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3416
3417 if schema.is_empty() {
3418 format!(
3420 concat!(
3421 #expertise,
3422 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3423 "Do not include any text outside the JSON object."
3424 ),
3425 #type_name
3426 )
3427 } else {
3428 format!(
3430 concat!(
3431 #expertise,
3432 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3433 ),
3434 schema
3435 )
3436 }
3437 }).as_str()
3438 }
3439 }
3440 };
3441
3442 let agent_init = match backend.as_str() {
3444 "gemini" => {
3445 if let Some(model_str) = model {
3446 quote! {
3447 use #crate_path::agent::impls::GeminiAgent;
3448 let agent = GeminiAgent::new().with_model_str(#model_str);
3449 }
3450 } else {
3451 quote! {
3452 use #crate_path::agent::impls::GeminiAgent;
3453 let agent = GeminiAgent::new();
3454 }
3455 }
3456 }
3457 "claude" => {
3458 if let Some(model_str) = model {
3459 quote! {
3460 use #crate_path::agent::impls::ClaudeCodeAgent;
3461 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3462 }
3463 } else {
3464 quote! {
3465 use #crate_path::agent::impls::ClaudeCodeAgent;
3466 let agent = ClaudeCodeAgent::new();
3467 }
3468 }
3469 }
3470 _ => {
3471 if let Some(model_str) = model {
3473 quote! {
3474 use #crate_path::agent::impls::ClaudeCodeAgent;
3475 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3476 }
3477 } else {
3478 quote! {
3479 use #crate_path::agent::impls::ClaudeCodeAgent;
3480 let agent = ClaudeCodeAgent::new();
3481 }
3482 }
3483 }
3484 };
3485
3486 let expanded = quote! {
3487 #[async_trait::async_trait]
3488 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3489 type Output = #output_type;
3490
3491 fn expertise(&self) -> &str {
3492 #enhanced_expertise
3493 }
3494
3495 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3496 #agent_init
3498
3499 let agent_ref = &agent;
3501 #crate_path::agent::retry::retry_execution(
3502 #max_retries,
3503 &intent,
3504 move |payload| {
3505 let payload = payload.clone();
3506 async move {
3507 let response = agent_ref.execute(payload).await?;
3509
3510 let json_str = #crate_path::extract_json(&response)
3512 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3513 message: format!("Failed to extract JSON: {}", e),
3514 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3515 })?;
3516
3517 serde_json::from_str::<Self::Output>(&json_str)
3519 .map_err(|e| {
3520 let reason = if e.is_eof() {
3522 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3523 } else if e.is_syntax() {
3524 #crate_path::agent::error::ParseErrorReason::InvalidJson
3525 } else {
3526 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3527 };
3528
3529 #crate_path::agent::AgentError::ParseError {
3530 message: format!("Failed to parse JSON: {}", e),
3531 reason,
3532 }
3533 })
3534 }
3535 }
3536 ).await
3537 }
3538
3539 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3540 #agent_init
3542 agent.is_available().await
3543 }
3544 }
3545 };
3546
3547 TokenStream::from(expanded)
3548}
3549
3550#[proc_macro_attribute]
3565pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3566 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3568 Ok(attrs) => attrs,
3569 Err(e) => return e.to_compile_error().into(),
3570 };
3571
3572 let input = parse_macro_input!(item as DeriveInput);
3574 let struct_name = &input.ident;
3575 let vis = &input.vis;
3576
3577 let expertise = agent_attrs
3578 .expertise
3579 .unwrap_or_else(|| String::from("general AI assistant"));
3580 let output_type = agent_attrs
3581 .output
3582 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3583 let backend = agent_attrs
3584 .backend
3585 .unwrap_or_else(|| String::from("claude"));
3586 let model = agent_attrs.model;
3587 let profile = agent_attrs.profile;
3588
3589 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3591 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3592
3593 let found_crate =
3595 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3596 let crate_path = match found_crate {
3597 FoundCrate::Itself => {
3598 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3599 quote!(::#ident)
3600 }
3601 FoundCrate::Name(name) => {
3602 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3603 quote!(::#ident)
3604 }
3605 };
3606
3607 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3609 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3610
3611 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3613 let type_path: syn::Type =
3615 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3616 quote! { #type_path }
3617 } else {
3618 match backend.as_str() {
3620 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3621 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3622 }
3623 };
3624
3625 let struct_def = quote! {
3627 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3628 inner: #inner_generic_ident,
3629 }
3630 };
3631
3632 let constructors = quote! {
3634 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3635 pub fn new(inner: #inner_generic_ident) -> Self {
3637 Self { inner }
3638 }
3639 }
3640 };
3641
3642 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3644 let default_impl = quote! {
3646 impl Default for #struct_name {
3647 fn default() -> Self {
3648 Self {
3649 inner: <#default_agent_type as Default>::default(),
3650 }
3651 }
3652 }
3653 };
3654 (quote! {}, default_impl)
3655 } else {
3656 let backend_constructors = generate_backend_constructors(
3658 struct_name,
3659 &backend,
3660 model.as_deref(),
3661 profile.as_deref(),
3662 &crate_path,
3663 );
3664 let default_impl = generate_default_impl(
3665 struct_name,
3666 &backend,
3667 model.as_deref(),
3668 profile.as_deref(),
3669 &crate_path,
3670 );
3671 (backend_constructors, default_impl)
3672 };
3673
3674 let enhanced_expertise = if is_string_output {
3676 quote! { #expertise }
3678 } else {
3679 let type_name = quote!(#output_type).to_string();
3681 quote! {
3682 {
3683 use std::sync::OnceLock;
3684 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3685
3686 EXPERTISE_CACHE.get_or_init(|| {
3687 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3689
3690 if schema.is_empty() {
3691 format!(
3693 concat!(
3694 #expertise,
3695 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3696 "Do not include any text outside the JSON object."
3697 ),
3698 #type_name
3699 )
3700 } else {
3701 format!(
3703 concat!(
3704 #expertise,
3705 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3706 ),
3707 schema
3708 )
3709 }
3710 }).as_str()
3711 }
3712 }
3713 };
3714
3715 let agent_impl = quote! {
3717 #[async_trait::async_trait]
3718 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3719 where
3720 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3721 {
3722 type Output = #output_type;
3723
3724 fn expertise(&self) -> &str {
3725 #enhanced_expertise
3726 }
3727
3728 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3729 let enhanced_payload = intent.prepend_text(self.expertise());
3731
3732 let response = self.inner.execute(enhanced_payload).await?;
3734
3735 let json_str = #crate_path::extract_json(&response)
3737 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3738 message: e.to_string(),
3739 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3740 })?;
3741
3742 serde_json::from_str(&json_str).map_err(|e| {
3744 let reason = if e.is_eof() {
3745 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3746 } else if e.is_syntax() {
3747 #crate_path::agent::error::ParseErrorReason::InvalidJson
3748 } else {
3749 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3750 };
3751 #crate_path::agent::AgentError::ParseError {
3752 message: e.to_string(),
3753 reason,
3754 }
3755 })
3756 }
3757
3758 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3759 self.inner.is_available().await
3760 }
3761 }
3762 };
3763
3764 let expanded = quote! {
3765 #struct_def
3766 #constructors
3767 #backend_constructors
3768 #default_impl
3769 #agent_impl
3770 };
3771
3772 TokenStream::from(expanded)
3773}
3774
3775#[proc_macro_derive(TypeMarker)]
3797pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3798 let input = parse_macro_input!(input as DeriveInput);
3799 let struct_name = &input.ident;
3800 let type_name_str = struct_name.to_string();
3801
3802 let found_crate =
3804 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3805 let crate_path = match found_crate {
3806 FoundCrate::Itself => {
3807 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3808 quote!(::#ident)
3809 }
3810 FoundCrate::Name(name) => {
3811 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3812 quote!(::#ident)
3813 }
3814 };
3815
3816 let expanded = quote! {
3817 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3818 const TYPE_NAME: &'static str = #type_name_str;
3819 }
3820 };
3821
3822 TokenStream::from(expanded)
3823}
3824
3825#[proc_macro_attribute]
3861pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
3862 let input = parse_macro_input!(item as syn::DeriveInput);
3863 let struct_name = &input.ident;
3864 let vis = &input.vis;
3865 let type_name_str = struct_name.to_string();
3866
3867 let default_fn_name = syn::Ident::new(
3869 &format!("default_{}_type", to_snake_case(&type_name_str)),
3870 struct_name.span(),
3871 );
3872
3873 let found_crate =
3875 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3876 let crate_path = match found_crate {
3877 FoundCrate::Itself => {
3878 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3879 quote!(::#ident)
3880 }
3881 FoundCrate::Name(name) => {
3882 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3883 quote!(::#ident)
3884 }
3885 };
3886
3887 let fields = match &input.data {
3889 syn::Data::Struct(data_struct) => match &data_struct.fields {
3890 syn::Fields::Named(fields) => &fields.named,
3891 _ => {
3892 return syn::Error::new_spanned(
3893 struct_name,
3894 "type_marker only works with structs with named fields",
3895 )
3896 .to_compile_error()
3897 .into();
3898 }
3899 },
3900 _ => {
3901 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
3902 .to_compile_error()
3903 .into();
3904 }
3905 };
3906
3907 let mut new_fields = vec![];
3909
3910 let default_fn_name_str = default_fn_name.to_string();
3912 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
3913
3914 new_fields.push(quote! {
3919 #[serde(default = #default_fn_name_lit)]
3920 __type: String
3921 });
3922
3923 for field in fields {
3925 new_fields.push(quote! { #field });
3926 }
3927
3928 let attrs = &input.attrs;
3930 let generics = &input.generics;
3931
3932 let expanded = quote! {
3933 fn #default_fn_name() -> String {
3935 #type_name_str.to_string()
3936 }
3937
3938 #(#attrs)*
3940 #vis struct #struct_name #generics {
3941 #(#new_fields),*
3942 }
3943
3944 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3946 const TYPE_NAME: &'static str = #type_name_str;
3947 }
3948 };
3949
3950 TokenStream::from(expanded)
3951}