1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
169
170 if is_vec {
171 let comment = if !field_docs.is_empty() {
174 format!(" // {}", field_docs)
175 } else {
176 String::new()
177 };
178
179 field_schema_parts.push(quote! {
180 {
181 let type_name = stringify!(#inner_type);
182 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
183 }
184 });
185
186 if let Some(inner) = inner_type
188 && !is_primitive_type(inner)
189 {
190 nested_type_collectors.push(quote! {
191 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
192 });
193 }
194 } else {
195 let field_type = &field.ty;
197 let is_primitive = is_primitive_type(field_type);
198
199 if !is_primitive {
200 let comment = if !field_docs.is_empty() {
203 format!(" // {}", field_docs)
204 } else {
205 String::new()
206 };
207
208 field_schema_parts.push(quote! {
209 {
210 let type_name = stringify!(#field_type);
211 format!(" {}: {};{}", #field_name_str, type_name, #comment)
212 }
213 });
214
215 nested_type_collectors.push(quote! {
217 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
218 });
219 } else {
220 let type_str = format_type_for_schema(&field.ty);
223 let comment = if !field_docs.is_empty() {
224 format!(" // {}", field_docs)
225 } else {
226 String::new()
227 };
228
229 field_schema_parts.push(quote! {
230 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
231 });
232 }
233 }
234 }
235
236 let mut header_lines = Vec::new();
251
252 if !struct_docs.is_empty() {
254 header_lines.push("/**".to_string());
255 header_lines.push(format!(" * {}", struct_docs));
256 header_lines.push(" */".to_string());
257 }
258
259 header_lines.push(format!("type {} = {{", struct_name));
261
262 quote! {
263 {
264 let mut all_lines: Vec<String> = Vec::new();
265
266 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
268 let mut seen_types = std::collections::HashSet::<String>::new();
269
270 for schema in nested_schemas {
271 if !schema.is_empty() {
272 if seen_types.insert(schema.clone()) {
274 all_lines.push(schema);
275 all_lines.push(String::new()); }
277 }
278 }
279
280 let mut lines: Vec<String> = Vec::new();
282 #(lines.push(#header_lines.to_string());)*
283 #(lines.push(#field_schema_parts);)*
284 lines.push("}".to_string());
285 all_lines.push(lines.join("\n"));
286
287 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
288 }
289 }
290}
291
292fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
294 if let syn::Type::Path(type_path) = ty
295 && let Some(last_segment) = type_path.path.segments.last()
296 && last_segment.ident == "Vec"
297 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
298 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
299 {
300 return (true, Some(inner_type));
301 }
302 (false, None)
303}
304
305fn is_primitive_type(ty: &syn::Type) -> bool {
307 if let syn::Type::Path(type_path) = ty
308 && let Some(last_segment) = type_path.path.segments.last()
309 {
310 let type_name = last_segment.ident.to_string();
311 matches!(
312 type_name.as_str(),
313 "String"
314 | "str"
315 | "i8"
316 | "i16"
317 | "i32"
318 | "i64"
319 | "i128"
320 | "isize"
321 | "u8"
322 | "u16"
323 | "u32"
324 | "u64"
325 | "u128"
326 | "usize"
327 | "f32"
328 | "f64"
329 | "bool"
330 | "Vec"
331 | "Option"
332 | "HashMap"
333 | "BTreeMap"
334 | "HashSet"
335 | "BTreeSet"
336 )
337 } else {
338 true
340 }
341}
342
343fn format_type_for_schema(ty: &syn::Type) -> String {
345 match ty {
347 syn::Type::Path(type_path) => {
348 let path = &type_path.path;
349 if let Some(last_segment) = path.segments.last() {
350 let type_name = last_segment.ident.to_string();
351
352 if type_name == "Option"
354 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
355 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
356 {
357 return format!("{} | null", format_type_for_schema(inner_type));
358 }
359
360 match type_name.as_str() {
362 "String" | "str" => "string".to_string(),
363 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
364 | "u64" | "u128" | "usize" => "number".to_string(),
365 "f32" | "f64" => "number".to_string(),
366 "bool" => "boolean".to_string(),
367 "Vec" => {
368 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
369 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
370 {
371 return format!("{}[]", format_type_for_schema(inner_type));
372 }
373 "array".to_string()
374 }
375 _ => type_name.to_lowercase(),
376 }
377 } else {
378 "unknown".to_string()
379 }
380 }
381 _ => "unknown".to_string(),
382 }
383}
384
/// Options read from `#[prompt(...)]` attributes by `parse_prompt_attributes`
/// (applied to enum variants in this file).
#[derive(Default)]
struct PromptAttributes {
    /// Set by `#[prompt(skip)]`; the variant is left out of the generated schema.
    skip: bool,
    /// Set by `#[prompt(rename = "...")]`; used instead of the variant's name.
    rename: Option<String>,
    /// Set by `#[prompt(description = "...")]` or the shorthand `#[prompt("...")]`.
    description: Option<String>,
}
392
393fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
396 let mut result = PromptAttributes::default();
397
398 for attr in attrs {
399 if attr.path().is_ident("prompt") {
400 if let Ok(meta_list) = attr.meta.require_list() {
402 if let Ok(metas) =
404 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
405 {
406 for meta in metas {
407 if let Meta::NameValue(nv) = meta {
408 if nv.path.is_ident("rename") {
409 if let syn::Expr::Lit(syn::ExprLit {
410 lit: syn::Lit::Str(lit_str),
411 ..
412 }) = nv.value
413 {
414 result.rename = Some(lit_str.value());
415 }
416 } else if nv.path.is_ident("description")
417 && let syn::Expr::Lit(syn::ExprLit {
418 lit: syn::Lit::Str(lit_str),
419 ..
420 }) = nv.value
421 {
422 result.description = Some(lit_str.value());
423 }
424 } else if let Meta::Path(path) = meta
425 && path.is_ident("skip")
426 {
427 result.skip = true;
428 }
429 }
430 }
431
432 let tokens_str = meta_list.tokens.to_string();
434 if tokens_str == "skip" {
435 result.skip = true;
436 }
437 }
438
439 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
441 result.description = Some(lit_str.value());
442 }
443 }
444 }
445 result
446}
447
448fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
450 for attr in attrs {
451 if attr.path().is_ident("serde")
452 && let Ok(meta_list) = attr.meta.require_list()
453 && let Ok(metas) =
454 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
455 {
456 for meta in metas {
457 if let Meta::NameValue(nv) = meta
458 && nv.path.is_ident("rename")
459 && let syn::Expr::Lit(syn::ExprLit {
460 lit: syn::Lit::Str(lit_str),
461 ..
462 }) = nv.value
463 {
464 return Some(lit_str.value());
465 }
466 }
467 }
468 }
469 None
470}
471
/// A casing rule taken from `#[serde(rename_all = "...")]`, applied to enum
/// variant names when building prompts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leave the name unchanged.
    // `from_str` never returns this variant; the allowance presumably silences
    // the never-constructed warning.
    #[allow(dead_code)]
    None,
    /// `"lowercase"`
    LowerCase,
    /// `"UPPERCASE"`
    UpperCase,
    /// `"PascalCase"`
    PascalCase,
    /// `"camelCase"`
    CamelCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
}
486
487impl RenameRule {
488 fn from_str(s: &str) -> Option<Self> {
490 match s {
491 "lowercase" => Some(Self::LowerCase),
492 "UPPERCASE" => Some(Self::UpperCase),
493 "PascalCase" => Some(Self::PascalCase),
494 "camelCase" => Some(Self::CamelCase),
495 "snake_case" => Some(Self::SnakeCase),
496 "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
497 "kebab-case" => Some(Self::KebabCase),
498 "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
499 _ => None,
500 }
501 }
502
503 fn apply(&self, name: &str) -> String {
505 match self {
506 Self::None => name.to_string(),
507 Self::LowerCase => name.to_lowercase(),
508 Self::UpperCase => name.to_uppercase(),
509 Self::PascalCase => name.to_string(), Self::CamelCase => {
511 let mut chars = name.chars();
513 match chars.next() {
514 None => String::new(),
515 Some(first) => first.to_lowercase().chain(chars).collect(),
516 }
517 }
518 Self::SnakeCase => {
519 let mut result = String::new();
521 for (i, ch) in name.chars().enumerate() {
522 if ch.is_uppercase() && i > 0 {
523 result.push('_');
524 }
525 result.push(ch.to_lowercase().next().unwrap());
526 }
527 result
528 }
529 Self::ScreamingSnakeCase => {
530 let mut result = String::new();
532 for (i, ch) in name.chars().enumerate() {
533 if ch.is_uppercase() && i > 0 {
534 result.push('_');
535 }
536 result.push(ch.to_uppercase().next().unwrap());
537 }
538 result
539 }
540 Self::KebabCase => {
541 let mut result = String::new();
543 for (i, ch) in name.chars().enumerate() {
544 if ch.is_uppercase() && i > 0 {
545 result.push('-');
546 }
547 result.push(ch.to_lowercase().next().unwrap());
548 }
549 result
550 }
551 Self::ScreamingKebabCase => {
552 let mut result = String::new();
554 for (i, ch) in name.chars().enumerate() {
555 if ch.is_uppercase() && i > 0 {
556 result.push('-');
557 }
558 result.push(ch.to_uppercase().next().unwrap());
559 }
560 result
561 }
562 }
563 }
564}
565
566fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
568 for attr in attrs {
569 if attr.path().is_ident("serde")
570 && let Ok(meta_list) = attr.meta.require_list()
571 {
572 if let Ok(metas) =
574 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
575 {
576 for meta in metas {
577 if let Meta::NameValue(nv) = meta
578 && nv.path.is_ident("rename_all")
579 && let syn::Expr::Lit(syn::ExprLit {
580 lit: syn::Lit::Str(lit_str),
581 ..
582 }) = nv.value
583 {
584 return RenameRule::from_str(&lit_str.value());
585 }
586 }
587 }
588 }
589 }
590 None
591}
592
/// Options read from `#[prompt(...)]` attributes on a struct field by
/// `parse_field_prompt_attrs`.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: the field is omitted from the generated schema and example.
    skip: bool,
    /// `rename = "..."`: parsed, but not consulted by the schema/example
    /// generators visible here. TODO(review): confirm where it is used.
    rename: Option<String>,
    /// `format_with = "..."`: raw string value. NOTE(review): presumably names
    /// a formatting function; its consumer is not visible here.
    format_with: Option<String>,
    /// `image`: in template mode the field's own `to_prompt_parts()` are
    /// appended to the output.
    image: bool,
    /// `example = "..."`: value used (as a JSON string) in the example output.
    example: Option<String>,
}
602
603fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
605 let mut result = FieldPromptAttrs::default();
606
607 for attr in attrs {
608 if attr.path().is_ident("prompt") {
609 if let Ok(meta_list) = attr.meta.require_list() {
611 if let Ok(metas) =
613 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
614 {
615 for meta in metas {
616 match meta {
617 Meta::Path(path) if path.is_ident("skip") => {
618 result.skip = true;
619 }
620 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
621 if let syn::Expr::Lit(syn::ExprLit {
622 lit: syn::Lit::Str(lit_str),
623 ..
624 }) = nv.value
625 {
626 result.rename = Some(lit_str.value());
627 }
628 }
629 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
630 if let syn::Expr::Lit(syn::ExprLit {
631 lit: syn::Lit::Str(lit_str),
632 ..
633 }) = nv.value
634 {
635 result.format_with = Some(lit_str.value());
636 }
637 }
638 Meta::Path(path) if path.is_ident("image") => {
639 result.image = true;
640 }
641 Meta::NameValue(nv) if nv.path.is_ident("example") => {
642 if let syn::Expr::Lit(syn::ExprLit {
643 lit: syn::Lit::Str(lit_str),
644 ..
645 }) = nv.value
646 {
647 result.example = Some(lit_str.value());
648 }
649 }
650 _ => {}
651 }
652 }
653 } else if meta_list.tokens.to_string() == "skip" {
654 result.skip = true;
656 } else if meta_list.tokens.to_string() == "image" {
657 result.image = true;
659 }
660 }
661 }
662 }
663
664 result
665}
666
667#[proc_macro_derive(ToPrompt, attributes(prompt))]
710pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
711 let input = parse_macro_input!(input as DeriveInput);
712
713 let found_crate =
714 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
715 let crate_path = match found_crate {
716 FoundCrate::Itself => {
717 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
719 quote!(::#ident)
720 }
721 FoundCrate::Name(name) => {
722 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
723 quote!(::#ident)
724 }
725 };
726
727 match &input.data {
729 Data::Enum(data_enum) => {
730 let enum_name = &input.ident;
732 let enum_docs = extract_doc_comments(&input.attrs);
733
734 let rename_rule = parse_serde_rename_all(&input.attrs);
736
737 let mut variant_lines = Vec::new();
750 let mut first_variant_name = None;
751
752 for variant in &data_enum.variants {
753 let variant_name = &variant.ident;
754 let variant_name_str = variant_name.to_string();
755
756 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
758
759 if prompt_attrs.skip {
761 continue;
762 }
763
764 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
770 prompt_rename.clone()
771 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
772 serde_rename
773 } else if let Some(rule) = rename_rule {
774 rule.apply(&variant_name_str)
775 } else {
776 variant_name_str.clone()
777 };
778
779 let variant_line = match &variant.fields {
781 syn::Fields::Unit => {
782 if let Some(desc) = &prompt_attrs.description {
784 format!(" | \"{}\" // {}", variant_value, desc)
785 } else {
786 let docs = extract_doc_comments(&variant.attrs);
787 if !docs.is_empty() {
788 format!(" | \"{}\" // {}", variant_value, docs)
789 } else {
790 format!(" | \"{}\"", variant_value)
791 }
792 }
793 }
794 syn::Fields::Named(fields) => {
795 let mut field_parts = vec![format!("type: \"{}\"", variant_value)];
797
798 for field in &fields.named {
799 let field_name = field.ident.as_ref().unwrap().to_string();
800 let field_type = format_type_for_schema(&field.ty);
801 field_parts.push(format!("{}: {}", field_name, field_type));
802 }
803
804 let field_str = field_parts.join(", ");
805 let comment = if let Some(desc) = &prompt_attrs.description {
806 format!(" // {}", desc)
807 } else {
808 let docs = extract_doc_comments(&variant.attrs);
809 if !docs.is_empty() {
810 format!(" // {}", docs)
811 } else {
812 String::new()
813 }
814 };
815
816 format!(" | {{ {} }}{}", field_str, comment)
817 }
818 syn::Fields::Unnamed(fields) => {
819 let field_types: Vec<String> = fields
821 .unnamed
822 .iter()
823 .map(|f| format_type_for_schema(&f.ty))
824 .collect();
825
826 let tuple_str = field_types.join(", ");
827 let comment = if let Some(desc) = &prompt_attrs.description {
828 format!(" // {}", desc)
829 } else {
830 let docs = extract_doc_comments(&variant.attrs);
831 if !docs.is_empty() {
832 format!(" // {}", docs)
833 } else {
834 String::new()
835 }
836 };
837
838 format!(" | [{}]{}", tuple_str, comment)
839 }
840 };
841
842 variant_lines.push(variant_line);
843
844 if first_variant_name.is_none() {
845 first_variant_name = Some(variant_value);
846 }
847 }
848
849 let mut lines = Vec::new();
851
852 if !enum_docs.is_empty() {
854 lines.push("/**".to_string());
855 lines.push(format!(" * {}", enum_docs));
856 lines.push(" */".to_string());
857 }
858
859 lines.push(format!("type {} =", enum_name));
861
862 for line in &variant_lines {
864 lines.push(line.clone());
865 }
866
867 if let Some(last) = lines.last_mut()
869 && !last.ends_with(';')
870 {
871 last.push(';');
872 }
873
874 if let Some(first_name) = first_variant_name {
876 lines.push("".to_string()); lines.push(format!("Example value: \"{}\"", first_name));
878 }
879
880 let prompt_string = lines.join("\n");
881 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
882
883 let mut match_arms = Vec::new();
885 for variant in &data_enum.variants {
886 let variant_name = &variant.ident;
887 let variant_name_str = variant_name.to_string();
888
889 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
891
892 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
898 prompt_rename.clone()
899 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
900 serde_rename
901 } else if let Some(rule) = rename_rule {
902 rule.apply(&variant_name_str)
903 } else {
904 variant_name_str.clone()
905 };
906
907 match &variant.fields {
909 syn::Fields::Unit => {
910 if prompt_attrs.skip {
912 match_arms.push(quote! {
913 Self::#variant_name => stringify!(#variant_name).to_string()
914 });
915 } else if let Some(desc) = &prompt_attrs.description {
916 match_arms.push(quote! {
917 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
918 });
919 } else {
920 let variant_docs = extract_doc_comments(&variant.attrs);
921 if !variant_docs.is_empty() {
922 match_arms.push(quote! {
923 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
924 });
925 } else {
926 match_arms.push(quote! {
927 Self::#variant_name => #variant_value.to_string()
928 });
929 }
930 }
931 }
932 syn::Fields::Named(fields) => {
933 let field_bindings: Vec<_> = fields
935 .named
936 .iter()
937 .map(|f| f.ident.as_ref().unwrap())
938 .collect();
939
940 let field_displays: Vec<_> = fields
941 .named
942 .iter()
943 .map(|f| {
944 let field_name = f.ident.as_ref().unwrap();
945 let field_name_str = field_name.to_string();
946 quote! {
947 format!("{}: {:?}", #field_name_str, #field_name)
948 }
949 })
950 .collect();
951
952 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
953 desc.clone()
954 } else {
955 let docs = extract_doc_comments(&variant.attrs);
956 if !docs.is_empty() {
957 docs
958 } else {
959 String::new()
960 }
961 };
962
963 if doc_or_desc.is_empty() {
964 match_arms.push(quote! {
965 Self::#variant_name { #(#field_bindings),* } => {
966 let fields = vec![#(#field_displays),*];
967 format!("{} {{ {} }}", #variant_value, fields.join(", "))
968 }
969 });
970 } else {
971 match_arms.push(quote! {
972 Self::#variant_name { #(#field_bindings),* } => {
973 let fields = vec![#(#field_displays),*];
974 format!("{}: {} {{ {} }}", #variant_value, #doc_or_desc, fields.join(", "))
975 }
976 });
977 }
978 }
979 syn::Fields::Unnamed(fields) => {
980 let field_count = fields.unnamed.len();
982 let field_bindings: Vec<_> = (0..field_count)
983 .map(|i| {
984 syn::Ident::new(
985 &format!("field{}", i),
986 proc_macro2::Span::call_site(),
987 )
988 })
989 .collect();
990
991 let field_displays: Vec<_> = field_bindings
992 .iter()
993 .map(|field_name| {
994 quote! {
995 format!("{:?}", #field_name)
996 }
997 })
998 .collect();
999
1000 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1001 desc.clone()
1002 } else {
1003 let docs = extract_doc_comments(&variant.attrs);
1004 if !docs.is_empty() {
1005 docs
1006 } else {
1007 String::new()
1008 }
1009 };
1010
1011 if doc_or_desc.is_empty() {
1012 match_arms.push(quote! {
1013 Self::#variant_name(#(#field_bindings),*) => {
1014 let fields = vec![#(#field_displays),*];
1015 format!("{}({})", #variant_value, fields.join(", "))
1016 }
1017 });
1018 } else {
1019 match_arms.push(quote! {
1020 Self::#variant_name(#(#field_bindings),*) => {
1021 let fields = vec![#(#field_displays),*];
1022 format!("{}: {}({})", #variant_value, #doc_or_desc, fields.join(", "))
1023 }
1024 });
1025 }
1026 }
1027 }
1028 }
1029
1030 let to_prompt_impl = if match_arms.is_empty() {
1031 quote! {
1033 fn to_prompt(&self) -> String {
1034 match *self {}
1035 }
1036 }
1037 } else {
1038 quote! {
1039 fn to_prompt(&self) -> String {
1040 match self {
1041 #(#match_arms),*
1042 }
1043 }
1044 }
1045 };
1046
1047 let expanded = quote! {
1048 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
1049 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1050 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
1051 }
1052
1053 #to_prompt_impl
1054
1055 fn prompt_schema() -> String {
1056 #prompt_string.to_string()
1057 }
1058 }
1059 };
1060
1061 TokenStream::from(expanded)
1062 }
1063 Data::Struct(data_struct) => {
1064 let mut template_attr = None;
1066 let mut template_file_attr = None;
1067 let mut mode_attr = None;
1068 let mut validate_attr = false;
1069 let mut type_marker_attr = false;
1070
1071 for attr in &input.attrs {
1072 if attr.path().is_ident("prompt") {
1073 if let Ok(metas) =
1075 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1076 {
1077 for meta in metas {
1078 match meta {
1079 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1080 if let syn::Expr::Lit(expr_lit) = nv.value
1081 && let syn::Lit::Str(lit_str) = expr_lit.lit
1082 {
1083 template_attr = Some(lit_str.value());
1084 }
1085 }
1086 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
1087 if let syn::Expr::Lit(expr_lit) = nv.value
1088 && let syn::Lit::Str(lit_str) = expr_lit.lit
1089 {
1090 template_file_attr = Some(lit_str.value());
1091 }
1092 }
1093 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1094 if let syn::Expr::Lit(expr_lit) = nv.value
1095 && let syn::Lit::Str(lit_str) = expr_lit.lit
1096 {
1097 mode_attr = Some(lit_str.value());
1098 }
1099 }
1100 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
1101 if let syn::Expr::Lit(expr_lit) = nv.value
1102 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1103 {
1104 validate_attr = lit_bool.value();
1105 }
1106 }
1107 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
1108 if let syn::Expr::Lit(expr_lit) = nv.value
1109 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1110 {
1111 type_marker_attr = lit_bool.value();
1112 }
1113 }
1114 Meta::Path(path) if path.is_ident("type_marker") => {
1115 type_marker_attr = true;
1117 }
1118 _ => {}
1119 }
1120 }
1121 }
1122 }
1123 }
1124
1125 if template_attr.is_some() && template_file_attr.is_some() {
1127 return syn::Error::new(
1128 input.ident.span(),
1129 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
1130 ).to_compile_error().into();
1131 }
1132
1133 let template_str = if let Some(file_path) = template_file_attr {
1135 let mut full_path = None;
1139
1140 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1142 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
1144
1145 if !is_trybuild {
1146 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1148 if candidate.exists() {
1149 full_path = Some(candidate);
1150 }
1151 } else {
1152 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1158 let workspace_root = &manifest_dir[..target_pos];
1159 let original_macros_dir = std::path::Path::new(workspace_root)
1161 .join("crates")
1162 .join("llm-toolkit-macros");
1163
1164 let candidate = original_macros_dir.join(&file_path);
1165 if candidate.exists() {
1166 full_path = Some(candidate);
1167 }
1168 }
1169 }
1170 }
1171
1172 if full_path.is_none() {
1174 let candidate = std::path::Path::new(&file_path).to_path_buf();
1175 if candidate.exists() {
1176 full_path = Some(candidate);
1177 }
1178 }
1179
1180 if full_path.is_none()
1183 && let Ok(current_dir) = std::env::current_dir()
1184 {
1185 let mut search_dir = current_dir.as_path();
1186 for _ in 0..10 {
1188 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1190 if macros_dir.exists() {
1191 let candidate = macros_dir.join(&file_path);
1192 if candidate.exists() {
1193 full_path = Some(candidate);
1194 break;
1195 }
1196 }
1197 let candidate = search_dir.join(&file_path);
1199 if candidate.exists() {
1200 full_path = Some(candidate);
1201 break;
1202 }
1203 if let Some(parent) = search_dir.parent() {
1204 search_dir = parent;
1205 } else {
1206 break;
1207 }
1208 }
1209 }
1210
1211 if full_path.is_none() {
1213 let mut error_msg = format!(
1215 "Template file '{}' not found at compile time.\n\nSearched in:",
1216 file_path
1217 );
1218
1219 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1220 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1221 error_msg.push_str(&format!("\n - {}", candidate.display()));
1222 }
1223
1224 if let Ok(current_dir) = std::env::current_dir() {
1225 let candidate = current_dir.join(&file_path);
1226 error_msg.push_str(&format!("\n - {}", candidate.display()));
1227 }
1228
1229 error_msg.push_str("\n\nPlease ensure:");
1230 error_msg.push_str("\n 1. The template file exists");
1231 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1232 error_msg.push_str("\n 3. There are no typos in the path");
1233
1234 return syn::Error::new(input.ident.span(), error_msg)
1235 .to_compile_error()
1236 .into();
1237 }
1238
1239 let final_path = full_path.unwrap();
1240
1241 match std::fs::read_to_string(&final_path) {
1243 Ok(content) => Some(content),
1244 Err(e) => {
1245 return syn::Error::new(
1246 input.ident.span(),
1247 format!(
1248 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1249 file_path,
1250 e,
1251 final_path.display()
1252 ),
1253 )
1254 .to_compile_error()
1255 .into();
1256 }
1257 }
1258 } else {
1259 template_attr
1260 };
1261
1262 if validate_attr && let Some(template) = &template_str {
1264 let mut env = minijinja::Environment::new();
1266 if let Err(e) = env.add_template("validation", template) {
1267 let warning_msg =
1269 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1270 let warning_ident = syn::Ident::new(
1271 "TEMPLATE_VALIDATION_WARNING",
1272 proc_macro2::Span::call_site(),
1273 );
1274 let _warning_tokens = quote! {
1275 #[deprecated(note = #warning_msg)]
1276 const #warning_ident: () = ();
1277 let _ = #warning_ident;
1278 };
1279 eprintln!("cargo:warning={}", warning_msg);
1281 }
1282
1283 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1285 &fields.named
1286 } else {
1287 panic!("Template validation is only supported for structs with named fields.");
1288 };
1289
1290 let field_names: std::collections::HashSet<String> = fields
1291 .iter()
1292 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1293 .collect();
1294
1295 let placeholders = parse_template_placeholders_with_mode(template);
1297
1298 for (placeholder_name, _mode) in &placeholders {
1299 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1300 let warning_msg = format!(
1301 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1302 placeholder_name
1303 );
1304 eprintln!("cargo:warning={}", warning_msg);
1305 }
1306 }
1307 }
1308
1309 let name = input.ident;
1310 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1311
1312 let struct_docs = extract_doc_comments(&input.attrs);
1314
1315 let is_mode_based =
1317 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1318
1319 let expanded = if is_mode_based || mode_attr.is_some() {
1320 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1322 &fields.named
1323 } else {
1324 panic!(
1325 "Mode-based prompt generation is only supported for structs with named fields."
1326 );
1327 };
1328
1329 let struct_name_str = name.to_string();
1330
1331 let has_default = input.attrs.iter().any(|attr| {
1333 if attr.path().is_ident("derive")
1334 && let Ok(meta_list) = attr.meta.require_list()
1335 {
1336 let tokens_str = meta_list.tokens.to_string();
1337 tokens_str.contains("Default")
1338 } else {
1339 false
1340 }
1341 });
1342
1343 let schema_parts = generate_schema_only_parts(
1354 &struct_name_str,
1355 &struct_docs,
1356 fields,
1357 &crate_path,
1358 type_marker_attr,
1359 );
1360
1361 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1363
1364 quote! {
1365 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1366 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1367 match mode {
1368 "schema_only" => #schema_parts,
1369 "example_only" => #example_parts,
1370 "full" | _ => {
1371 let mut parts = Vec::new();
1373
1374 let schema_parts = #schema_parts;
1376 parts.extend(schema_parts);
1377
1378 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1380 parts.push(#crate_path::prompt::PromptPart::Text(
1381 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1382 ));
1383
1384 let example_parts = #example_parts;
1386 parts.extend(example_parts);
1387
1388 parts
1389 }
1390 }
1391 }
1392
1393 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1394 self.to_prompt_parts_with_mode("full")
1395 }
1396
1397 fn to_prompt(&self) -> String {
1398 self.to_prompt_parts()
1399 .into_iter()
1400 .filter_map(|part| match part {
1401 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1402 _ => None,
1403 })
1404 .collect::<Vec<_>>()
1405 .join("\n")
1406 }
1407
1408 fn prompt_schema() -> String {
1409 use std::sync::OnceLock;
1410 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1411
1412 SCHEMA_CACHE.get_or_init(|| {
1413 let schema_parts = #schema_parts;
1414 schema_parts
1415 .into_iter()
1416 .filter_map(|part| match part {
1417 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1418 _ => None,
1419 })
1420 .collect::<Vec<_>>()
1421 .join("\n")
1422 }).clone()
1423 }
1424 }
1425 }
1426 } else if let Some(template) = template_str {
1427 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1430 &fields.named
1431 } else {
1432 panic!(
1433 "Template prompt generation is only supported for structs with named fields."
1434 );
1435 };
1436
1437 let placeholders = parse_template_placeholders_with_mode(&template);
1439 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1441 mode.is_some()
1442 && fields
1443 .iter()
1444 .any(|f| f.ident.as_ref().unwrap() == field_name)
1445 });
1446
1447 let mut image_field_parts = Vec::new();
1448 for f in fields.iter() {
1449 let field_name = f.ident.as_ref().unwrap();
1450 let attrs = parse_field_prompt_attrs(&f.attrs);
1451
1452 if attrs.image {
1453 image_field_parts.push(quote! {
1455 parts.extend(self.#field_name.to_prompt_parts());
1456 });
1457 }
1458 }
1459
1460 if has_mode_syntax {
1462 let mut context_fields = Vec::new();
1464 let mut modified_template = template.clone();
1465
1466 for (field_name, mode_opt) in &placeholders {
1468 if let Some(mode) = mode_opt {
1469 let unique_key = format!("{}__{}", field_name, mode);
1471
1472 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1474 let replacement = format!("{{{{ {} }}}}", unique_key);
1475 modified_template = modified_template.replace(&pattern, &replacement);
1476
1477 let field_ident =
1479 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1480
1481 context_fields.push(quote! {
1483 context.insert(
1484 #unique_key.to_string(),
1485 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1486 );
1487 });
1488 }
1489 }
1490
1491 for field in fields.iter() {
1493 let field_name = field.ident.as_ref().unwrap();
1494 let field_name_str = field_name.to_string();
1495
1496 let has_mode_entry = placeholders
1498 .iter()
1499 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1500
1501 if !has_mode_entry {
1502 let is_primitive = match &field.ty {
1505 syn::Type::Path(type_path) => {
1506 if let Some(segment) = type_path.path.segments.last() {
1507 let type_name = segment.ident.to_string();
1508 matches!(
1509 type_name.as_str(),
1510 "String"
1511 | "str"
1512 | "i8"
1513 | "i16"
1514 | "i32"
1515 | "i64"
1516 | "i128"
1517 | "isize"
1518 | "u8"
1519 | "u16"
1520 | "u32"
1521 | "u64"
1522 | "u128"
1523 | "usize"
1524 | "f32"
1525 | "f64"
1526 | "bool"
1527 | "char"
1528 )
1529 } else {
1530 false
1531 }
1532 }
1533 _ => false,
1534 };
1535
1536 if is_primitive {
1537 context_fields.push(quote! {
1538 context.insert(
1539 #field_name_str.to_string(),
1540 minijinja::Value::from_serialize(&self.#field_name)
1541 );
1542 });
1543 } else {
1544 context_fields.push(quote! {
1546 context.insert(
1547 #field_name_str.to_string(),
1548 minijinja::Value::from(self.#field_name.to_prompt())
1549 );
1550 });
1551 }
1552 }
1553 }
1554
1555 quote! {
1556 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1557 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1558 let mut parts = Vec::new();
1559
1560 #(#image_field_parts)*
1562
1563 let text = {
1565 let mut env = minijinja::Environment::new();
1566 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1567 panic!("Failed to parse template: {}", e)
1568 });
1569
1570 let tmpl = env.get_template("prompt").unwrap();
1571
1572 let mut context = std::collections::HashMap::new();
1573 #(#context_fields)*
1574
1575 tmpl.render(context).unwrap_or_else(|e| {
1576 format!("Failed to render prompt: {}", e)
1577 })
1578 };
1579
1580 if !text.is_empty() {
1581 parts.push(#crate_path::prompt::PromptPart::Text(text));
1582 }
1583
1584 parts
1585 }
1586
1587 fn to_prompt(&self) -> String {
1588 let mut env = minijinja::Environment::new();
1590 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1591 panic!("Failed to parse template: {}", e)
1592 });
1593
1594 let tmpl = env.get_template("prompt").unwrap();
1595
1596 let mut context = std::collections::HashMap::new();
1597 #(#context_fields)*
1598
1599 tmpl.render(context).unwrap_or_else(|e| {
1600 format!("Failed to render prompt: {}", e)
1601 })
1602 }
1603
1604 fn prompt_schema() -> String {
1605 String::new() }
1607 }
1608 }
1609 } else {
1610 quote! {
1612 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1613 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1614 let mut parts = Vec::new();
1615
1616 #(#image_field_parts)*
1618
1619 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1621 format!("Failed to render prompt: {}", e)
1622 });
1623 if !text.is_empty() {
1624 parts.push(#crate_path::prompt::PromptPart::Text(text));
1625 }
1626
1627 parts
1628 }
1629
1630 fn to_prompt(&self) -> String {
1631 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1632 format!("Failed to render prompt: {}", e)
1633 })
1634 }
1635
1636 fn prompt_schema() -> String {
1637 String::new() }
1639 }
1640 }
1641 }
1642 } else {
1643 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1646 &fields.named
1647 } else {
1648 panic!(
1649 "Default prompt generation is only supported for structs with named fields."
1650 );
1651 };
1652
1653 let mut text_field_parts = Vec::new();
1655 let mut image_field_parts = Vec::new();
1656
1657 for f in fields.iter() {
1658 let field_name = f.ident.as_ref().unwrap();
1659 let attrs = parse_field_prompt_attrs(&f.attrs);
1660
1661 if attrs.skip {
1663 continue;
1664 }
1665
1666 if attrs.image {
1667 image_field_parts.push(quote! {
1669 parts.extend(self.#field_name.to_prompt_parts());
1670 });
1671 } else {
1672 let key = if let Some(rename) = attrs.rename {
1678 rename
1679 } else {
1680 let doc_comment = extract_doc_comments(&f.attrs);
1681 if !doc_comment.is_empty() {
1682 doc_comment
1683 } else {
1684 field_name.to_string()
1685 }
1686 };
1687
1688 let value_expr = if let Some(format_with) = attrs.format_with {
1690 let func_path: syn::Path =
1692 syn::parse_str(&format_with).unwrap_or_else(|_| {
1693 panic!("Invalid function path: {}", format_with)
1694 });
1695 quote! { #func_path(&self.#field_name) }
1696 } else {
1697 quote! { self.#field_name.to_prompt() }
1698 };
1699
1700 text_field_parts.push(quote! {
1701 text_parts.push(format!("{}: {}", #key, #value_expr));
1702 });
1703 }
1704 }
1705
1706 quote! {
1708 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1709 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1710 let mut parts = Vec::new();
1711
1712 #(#image_field_parts)*
1714
1715 let mut text_parts = Vec::new();
1717 #(#text_field_parts)*
1718
1719 if !text_parts.is_empty() {
1720 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1721 }
1722
1723 parts
1724 }
1725
1726 fn to_prompt(&self) -> String {
1727 let mut text_parts = Vec::new();
1728 #(#text_field_parts)*
1729 text_parts.join("\n")
1730 }
1731
1732 fn prompt_schema() -> String {
1733 String::new() }
1735 }
1736 }
1737 };
1738
1739 TokenStream::from(expanded)
1740 }
1741 Data::Union(_) => {
1742 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1743 }
1744 }
1745}
1746
1747#[derive(Debug, Clone)]
1749struct TargetInfo {
1750 name: String,
1751 template: Option<String>,
1752 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1753}
1754
/// Options parsed from one field-level `#[prompt_for(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Leave the field out of the target's output.
    skip: bool,
    /// Label emitted instead of the field name.
    rename: Option<String>,
    /// Path of a function invoked as `func(&self.field)` to format the value.
    format_with: Option<String>,
    /// Emit the field through `to_prompt_parts()` instead of as a `key: value` text line.
    image: bool,
    /// Set when the attribute named a target explicitly. A field carrying any such
    /// attribute is only included in the targets it was configured for.
    include_only: bool,
}
1764
1765fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1767 let mut configs = Vec::new();
1768
1769 for attr in attrs {
1770 if attr.path().is_ident("prompt_for")
1771 && let Ok(meta_list) = attr.meta.require_list()
1772 {
1773 if meta_list.tokens.to_string() == "skip" {
1775 let config = FieldTargetConfig {
1777 skip: true,
1778 ..Default::default()
1779 };
1780 configs.push(("*".to_string(), config));
1781 } else if let Ok(metas) =
1782 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1783 {
1784 let mut target_name = None;
1785 let mut config = FieldTargetConfig::default();
1786
1787 for meta in metas {
1788 match meta {
1789 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1790 if let syn::Expr::Lit(syn::ExprLit {
1791 lit: syn::Lit::Str(lit_str),
1792 ..
1793 }) = nv.value
1794 {
1795 target_name = Some(lit_str.value());
1796 }
1797 }
1798 Meta::Path(path) if path.is_ident("skip") => {
1799 config.skip = true;
1800 }
1801 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1802 if let syn::Expr::Lit(syn::ExprLit {
1803 lit: syn::Lit::Str(lit_str),
1804 ..
1805 }) = nv.value
1806 {
1807 config.rename = Some(lit_str.value());
1808 }
1809 }
1810 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1811 if let syn::Expr::Lit(syn::ExprLit {
1812 lit: syn::Lit::Str(lit_str),
1813 ..
1814 }) = nv.value
1815 {
1816 config.format_with = Some(lit_str.value());
1817 }
1818 }
1819 Meta::Path(path) if path.is_ident("image") => {
1820 config.image = true;
1821 }
1822 _ => {}
1823 }
1824 }
1825
1826 if let Some(name) = target_name {
1827 config.include_only = true;
1828 configs.push((name, config));
1829 }
1830 }
1831 }
1832 }
1833
1834 configs
1835}
1836
1837fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1839 let mut targets = Vec::new();
1840
1841 for attr in attrs {
1842 if attr.path().is_ident("prompt_for")
1843 && let Ok(meta_list) = attr.meta.require_list()
1844 && let Ok(metas) =
1845 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1846 {
1847 let mut target_name = None;
1848 let mut template = None;
1849
1850 for meta in metas {
1851 match meta {
1852 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1853 if let syn::Expr::Lit(syn::ExprLit {
1854 lit: syn::Lit::Str(lit_str),
1855 ..
1856 }) = nv.value
1857 {
1858 target_name = Some(lit_str.value());
1859 }
1860 }
1861 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1862 if let syn::Expr::Lit(syn::ExprLit {
1863 lit: syn::Lit::Str(lit_str),
1864 ..
1865 }) = nv.value
1866 {
1867 template = Some(lit_str.value());
1868 }
1869 }
1870 _ => {}
1871 }
1872 }
1873
1874 if let Some(name) = target_name {
1875 targets.push(TargetInfo {
1876 name,
1877 template,
1878 field_configs: std::collections::HashMap::new(),
1879 });
1880 }
1881 }
1882 }
1883
1884 targets
1885}
1886
1887#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1888pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1889 let input = parse_macro_input!(input as DeriveInput);
1890
1891 let found_crate =
1892 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1893 let crate_path = match found_crate {
1894 FoundCrate::Itself => {
1895 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1897 quote!(::#ident)
1898 }
1899 FoundCrate::Name(name) => {
1900 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1901 quote!(::#ident)
1902 }
1903 };
1904
1905 let data_struct = match &input.data {
1907 Data::Struct(data) => data,
1908 _ => {
1909 return syn::Error::new(
1910 input.ident.span(),
1911 "`#[derive(ToPromptSet)]` is only supported for structs",
1912 )
1913 .to_compile_error()
1914 .into();
1915 }
1916 };
1917
1918 let fields = match &data_struct.fields {
1919 syn::Fields::Named(fields) => &fields.named,
1920 _ => {
1921 return syn::Error::new(
1922 input.ident.span(),
1923 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1924 )
1925 .to_compile_error()
1926 .into();
1927 }
1928 };
1929
1930 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1932
1933 for field in fields.iter() {
1935 let field_name = field.ident.as_ref().unwrap().to_string();
1936 let field_configs = parse_prompt_for_attrs(&field.attrs);
1937
1938 for (target_name, config) in field_configs {
1939 if target_name == "*" {
1940 for target in &mut targets {
1942 target
1943 .field_configs
1944 .entry(field_name.clone())
1945 .or_insert_with(FieldTargetConfig::default)
1946 .skip = config.skip;
1947 }
1948 } else {
1949 let target_exists = targets.iter().any(|t| t.name == target_name);
1951 if !target_exists {
1952 targets.push(TargetInfo {
1954 name: target_name.clone(),
1955 template: None,
1956 field_configs: std::collections::HashMap::new(),
1957 });
1958 }
1959
1960 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1961
1962 target.field_configs.insert(field_name.clone(), config);
1963 }
1964 }
1965 }
1966
1967 let mut match_arms = Vec::new();
1969
1970 for target in &targets {
1971 let target_name = &target.name;
1972
1973 if let Some(template_str) = &target.template {
1974 let mut image_parts = Vec::new();
1976
1977 for field in fields.iter() {
1978 let field_name = field.ident.as_ref().unwrap();
1979 let field_name_str = field_name.to_string();
1980
1981 if let Some(config) = target.field_configs.get(&field_name_str)
1982 && config.image
1983 {
1984 image_parts.push(quote! {
1985 parts.extend(self.#field_name.to_prompt_parts());
1986 });
1987 }
1988 }
1989
1990 match_arms.push(quote! {
1991 #target_name => {
1992 let mut parts = Vec::new();
1993
1994 #(#image_parts)*
1995
1996 let text = #crate_path::prompt::render_prompt(#template_str, self)
1997 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1998 target: #target_name.to_string(),
1999 source: e,
2000 })?;
2001
2002 if !text.is_empty() {
2003 parts.push(#crate_path::prompt::PromptPart::Text(text));
2004 }
2005
2006 Ok(parts)
2007 }
2008 });
2009 } else {
2010 let mut text_field_parts = Vec::new();
2012 let mut image_field_parts = Vec::new();
2013
2014 for field in fields.iter() {
2015 let field_name = field.ident.as_ref().unwrap();
2016 let field_name_str = field_name.to_string();
2017
2018 let config = target.field_configs.get(&field_name_str);
2020
2021 if let Some(cfg) = config
2023 && cfg.skip
2024 {
2025 continue;
2026 }
2027
2028 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
2032 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
2033 .iter()
2034 .any(|(name, _)| name != "*");
2035
2036 if has_any_target_specific_config && !is_explicitly_for_this_target {
2037 continue;
2038 }
2039
2040 if let Some(cfg) = config {
2041 if cfg.image {
2042 image_field_parts.push(quote! {
2043 parts.extend(self.#field_name.to_prompt_parts());
2044 });
2045 } else {
2046 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
2047
2048 let value_expr = if let Some(format_with) = &cfg.format_with {
2049 match syn::parse_str::<syn::Path>(format_with) {
2051 Ok(func_path) => quote! { #func_path(&self.#field_name) },
2052 Err(_) => {
2053 let error_msg = format!(
2055 "Invalid function path in format_with: '{}'",
2056 format_with
2057 );
2058 quote! {
2059 compile_error!(#error_msg);
2060 String::new()
2061 }
2062 }
2063 }
2064 } else {
2065 quote! { self.#field_name.to_prompt() }
2066 };
2067
2068 text_field_parts.push(quote! {
2069 text_parts.push(format!("{}: {}", #key, #value_expr));
2070 });
2071 }
2072 } else {
2073 text_field_parts.push(quote! {
2075 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
2076 });
2077 }
2078 }
2079
2080 match_arms.push(quote! {
2081 #target_name => {
2082 let mut parts = Vec::new();
2083
2084 #(#image_field_parts)*
2085
2086 let mut text_parts = Vec::new();
2087 #(#text_field_parts)*
2088
2089 if !text_parts.is_empty() {
2090 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2091 }
2092
2093 Ok(parts)
2094 }
2095 });
2096 }
2097 }
2098
2099 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
2101
2102 match_arms.push(quote! {
2104 _ => {
2105 let available = vec![#(#target_names.to_string()),*];
2106 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
2107 target: target.to_string(),
2108 available,
2109 })
2110 }
2111 });
2112
2113 let struct_name = &input.ident;
2114 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2115
2116 let expanded = quote! {
2117 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
2118 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
2119 match target {
2120 #(#match_arms)*
2121 }
2122 }
2123 }
2124 };
2125
2126 TokenStream::from(expanded)
2127}
2128
2129struct TypeList {
2131 types: Punctuated<syn::Type, Token![,]>,
2132}
2133
2134impl Parse for TypeList {
2135 fn parse(input: ParseStream) -> syn::Result<Self> {
2136 Ok(TypeList {
2137 types: Punctuated::parse_terminated(input)?,
2138 })
2139 }
2140}
2141
2142#[proc_macro]
2166pub fn examples_section(input: TokenStream) -> TokenStream {
2167 let input = parse_macro_input!(input as TypeList);
2168
2169 let found_crate =
2170 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2171 let _crate_path = match found_crate {
2172 FoundCrate::Itself => quote!(crate),
2173 FoundCrate::Name(name) => {
2174 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2175 quote!(::#ident)
2176 }
2177 };
2178
2179 let mut type_sections = Vec::new();
2181
2182 for ty in input.types.iter() {
2183 let type_name_str = quote!(#ty).to_string();
2185
2186 type_sections.push(quote! {
2188 {
2189 let type_name = #type_name_str;
2190 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
2191 format!("---\n#### `{}`\n{}", type_name, json_example)
2192 }
2193 });
2194 }
2195
2196 let expanded = quote! {
2198 {
2199 let mut sections = Vec::new();
2200 sections.push("---".to_string());
2201 sections.push("### Examples".to_string());
2202 sections.push("".to_string());
2203 sections.push("Here are examples of the data structures you should use.".to_string());
2204 sections.push("".to_string());
2205
2206 #(sections.push(#type_sections);)*
2207
2208 sections.push("---".to_string());
2209
2210 sections.join("\n")
2211 }
2212 };
2213
2214 TokenStream::from(expanded)
2215}
2216
2217fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
2219 for attr in attrs {
2220 if attr.path().is_ident("prompt_for")
2221 && let Ok(meta_list) = attr.meta.require_list()
2222 && let Ok(metas) =
2223 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2224 {
2225 let mut target_type = None;
2226 let mut template = None;
2227
2228 for meta in metas {
2229 match meta {
2230 Meta::NameValue(nv) if nv.path.is_ident("target") => {
2231 if let syn::Expr::Lit(syn::ExprLit {
2232 lit: syn::Lit::Str(lit_str),
2233 ..
2234 }) = nv.value
2235 {
2236 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2238 }
2239 }
2240 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2241 if let syn::Expr::Lit(syn::ExprLit {
2242 lit: syn::Lit::Str(lit_str),
2243 ..
2244 }) = nv.value
2245 {
2246 template = Some(lit_str.value());
2247 }
2248 }
2249 _ => {}
2250 }
2251 }
2252
2253 if let (Some(target), Some(tmpl)) = (target_type, template) {
2254 return (target, tmpl);
2255 }
2256 }
2257 }
2258
2259 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2260}
2261
2262#[proc_macro_attribute]
2296pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2297 let input = parse_macro_input!(item as DeriveInput);
2298
2299 let found_crate =
2300 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2301 let crate_path = match found_crate {
2302 FoundCrate::Itself => {
2303 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2305 quote!(::#ident)
2306 }
2307 FoundCrate::Name(name) => {
2308 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2309 quote!(::#ident)
2310 }
2311 };
2312
2313 let enum_data = match &input.data {
2315 Data::Enum(data) => data,
2316 _ => {
2317 return syn::Error::new(
2318 input.ident.span(),
2319 "`#[define_intent]` can only be applied to enums",
2320 )
2321 .to_compile_error()
2322 .into();
2323 }
2324 };
2325
2326 let mut prompt_template = None;
2328 let mut extractor_tag = None;
2329 let mut mode = None;
2330
2331 for attr in &input.attrs {
2332 if attr.path().is_ident("intent")
2333 && let Ok(metas) =
2334 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2335 {
2336 for meta in metas {
2337 match meta {
2338 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2339 if let syn::Expr::Lit(syn::ExprLit {
2340 lit: syn::Lit::Str(lit_str),
2341 ..
2342 }) = nv.value
2343 {
2344 prompt_template = Some(lit_str.value());
2345 }
2346 }
2347 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2348 if let syn::Expr::Lit(syn::ExprLit {
2349 lit: syn::Lit::Str(lit_str),
2350 ..
2351 }) = nv.value
2352 {
2353 extractor_tag = Some(lit_str.value());
2354 }
2355 }
2356 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2357 if let syn::Expr::Lit(syn::ExprLit {
2358 lit: syn::Lit::Str(lit_str),
2359 ..
2360 }) = nv.value
2361 {
2362 mode = Some(lit_str.value());
2363 }
2364 }
2365 _ => {}
2366 }
2367 }
2368 }
2369 }
2370
2371 let mode = mode.unwrap_or_else(|| "single".to_string());
2373
2374 if mode != "single" && mode != "multi_tag" {
2376 return syn::Error::new(
2377 input.ident.span(),
2378 "`mode` must be either \"single\" or \"multi_tag\"",
2379 )
2380 .to_compile_error()
2381 .into();
2382 }
2383
2384 let prompt_template = match prompt_template {
2386 Some(p) => p,
2387 None => {
2388 return syn::Error::new(
2389 input.ident.span(),
2390 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2391 )
2392 .to_compile_error()
2393 .into();
2394 }
2395 };
2396
2397 if mode == "multi_tag" {
2399 let enum_name = &input.ident;
2400 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2401 return generate_multi_tag_output(
2402 &input,
2403 enum_name,
2404 enum_data,
2405 prompt_template,
2406 actions_doc,
2407 );
2408 }
2409
2410 let extractor_tag = match extractor_tag {
2412 Some(t) => t,
2413 None => {
2414 return syn::Error::new(
2415 input.ident.span(),
2416 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2417 )
2418 .to_compile_error()
2419 .into();
2420 }
2421 };
2422
2423 let enum_name = &input.ident;
2425 let enum_docs = extract_doc_comments(&input.attrs);
2426
2427 let mut intents_doc_lines = Vec::new();
2428
2429 if !enum_docs.is_empty() {
2431 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2432 } else {
2433 intents_doc_lines.push(format!("{}:", enum_name));
2434 }
2435 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2437
2438 for variant in &enum_data.variants {
2440 let variant_name = &variant.ident;
2441 let variant_docs = extract_doc_comments(&variant.attrs);
2442
2443 if !variant_docs.is_empty() {
2444 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2445 } else {
2446 intents_doc_lines.push(format!("- {}", variant_name));
2447 }
2448 }
2449
2450 let intents_doc_str = intents_doc_lines.join("\n");
2451
2452 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2454 let user_variables: Vec<String> = placeholders
2455 .iter()
2456 .filter_map(|(name, _)| {
2457 if name != "intents_doc" {
2458 Some(name.clone())
2459 } else {
2460 None
2461 }
2462 })
2463 .collect();
2464
2465 let enum_name_str = enum_name.to_string();
2467 let snake_case_name = to_snake_case(&enum_name_str);
2468 let function_name = syn::Ident::new(
2469 &format!("build_{}_prompt", snake_case_name),
2470 proc_macro2::Span::call_site(),
2471 );
2472
2473 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2475 .iter()
2476 .map(|var| {
2477 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2478 quote! { #ident: &str }
2479 })
2480 .collect();
2481
2482 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2484 .iter()
2485 .map(|var| {
2486 let var_str = var.clone();
2487 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2488 quote! {
2489 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2490 }
2491 })
2492 .collect();
2493
2494 let converted_template = prompt_template.clone();
2496
2497 let extractor_name = syn::Ident::new(
2499 &format!("{}Extractor", enum_name),
2500 proc_macro2::Span::call_site(),
2501 );
2502
2503 let filtered_attrs: Vec<_> = input
2505 .attrs
2506 .iter()
2507 .filter(|attr| !attr.path().is_ident("intent"))
2508 .collect();
2509
2510 let vis = &input.vis;
2512 let generics = &input.generics;
2513 let variants = &enum_data.variants;
2514 let enum_output = quote! {
2515 #(#filtered_attrs)*
2516 #vis enum #enum_name #generics {
2517 #variants
2518 }
2519 };
2520
2521 let expanded = quote! {
2523 #enum_output
2525
2526 pub fn #function_name(#(#function_params),*) -> String {
2528 let mut env = minijinja::Environment::new();
2529 env.add_template("prompt", #converted_template)
2530 .expect("Failed to parse intent prompt template");
2531
2532 let tmpl = env.get_template("prompt").unwrap();
2533
2534 let mut __template_context = std::collections::HashMap::new();
2535
2536 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2538
2539 #(#context_insertions)*
2541
2542 tmpl.render(&__template_context)
2543 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2544 }
2545
2546 pub struct #extractor_name;
2548
2549 impl #extractor_name {
2550 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2551 }
2552
2553 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2554 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2555 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2557 }
2558 }
2559 };
2560
2561 TokenStream::from(expanded)
2562}
2563
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when it is not the first
/// character and the previous character was not uppercase. Runs of capitals therefore
/// stay joined: `"HTTPServer"` becomes `"httpserver"`. Only the first character of each
/// uppercase letter's lowercase mapping is kept.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut last_was_upper = false;

    for (index, c) in s.chars().enumerate() {
        let is_upper = c.is_uppercase();
        if !is_upper {
            snake.push(c);
        } else {
            if index != 0 && !last_was_upper {
                snake.push('_');
            }
            snake.push(c.to_lowercase().next().unwrap());
        }
        last_was_upper = is_upper;
    }

    snake
}
2584
/// Settings read from a variant-level `#[action(...)]` attribute.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// Element name given by `tag = "..."`; `None` when the variant has no tag.
    tag: Option<String>,
}
2590
2591fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2592 let mut result = ActionAttrs::default();
2593
2594 for attr in attrs {
2595 if attr.path().is_ident("action")
2596 && let Ok(meta_list) = attr.meta.require_list()
2597 && let Ok(metas) =
2598 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2599 {
2600 for meta in metas {
2601 if let Meta::NameValue(nv) = meta
2602 && nv.path.is_ident("tag")
2603 && let syn::Expr::Lit(syn::ExprLit {
2604 lit: syn::Lit::Str(lit_str),
2605 ..
2606 }) = nv.value
2607 {
2608 result.tag = Some(lit_str.value());
2609 }
2610 }
2611 }
2612 }
2613
2614 result
2615}
2616
/// Markers read from a variant field's `#[action(...)]` attribute.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// `#[action(attribute)]`: the field is rendered as an XML attribute.
    is_attribute: bool,
    /// `#[action(inner_text)]`: the field is rendered as the element's inner text.
    is_inner_text: bool,
}
2623
2624fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2625 let mut result = FieldActionAttrs::default();
2626
2627 for attr in attrs {
2628 if attr.path().is_ident("action")
2629 && let Ok(meta_list) = attr.meta.require_list()
2630 {
2631 let tokens_str = meta_list.tokens.to_string();
2632 if tokens_str == "attribute" {
2633 result.is_attribute = true;
2634 } else if tokens_str == "inner_text" {
2635 result.is_inner_text = true;
2636 }
2637 }
2638 }
2639
2640 result
2641}
2642
2643fn generate_multi_tag_actions_doc(
2645 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2646) -> String {
2647 let mut doc_lines = Vec::new();
2648
2649 for variant in variants {
2650 let action_attrs = parse_action_attrs(&variant.attrs);
2651
2652 if let Some(tag) = action_attrs.tag {
2653 let variant_docs = extract_doc_comments(&variant.attrs);
2654
2655 match &variant.fields {
2656 syn::Fields::Unit => {
2657 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2659 }
2660 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2661 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2663 }
2664 syn::Fields::Named(fields) => {
2665 let mut attrs_str = Vec::new();
2667 let mut has_inner_text = false;
2668
2669 for field in &fields.named {
2670 let field_name = field.ident.as_ref().unwrap();
2671 let field_attrs = parse_field_action_attrs(&field.attrs);
2672
2673 if field_attrs.is_attribute {
2674 attrs_str.push(format!("{}=\"...\"", field_name));
2675 } else if field_attrs.is_inner_text {
2676 has_inner_text = true;
2677 }
2678 }
2679
2680 let attrs_part = if !attrs_str.is_empty() {
2681 format!(" {}", attrs_str.join(" "))
2682 } else {
2683 String::new()
2684 };
2685
2686 if has_inner_text {
2687 doc_lines.push(format!(
2688 "- `<{}{}>...</{}>`: {}",
2689 tag, attrs_part, tag, variant_docs
2690 ));
2691 } else if !attrs_str.is_empty() {
2692 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2693 } else {
2694 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2695 }
2696
2697 for field in &fields.named {
2699 let field_name = field.ident.as_ref().unwrap();
2700 let field_attrs = parse_field_action_attrs(&field.attrs);
2701 let field_docs = extract_doc_comments(&field.attrs);
2702
2703 if field_attrs.is_attribute {
2704 doc_lines
2705 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2706 } else if field_attrs.is_inner_text {
2707 doc_lines
2708 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2709 }
2710 }
2711 }
2712 _ => {
2713 }
2715 }
2716 }
2717 }
2718
2719 doc_lines.join("\n")
2720}
2721
2722fn generate_tags_regex(
2724 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2725) -> String {
2726 let mut tag_names = Vec::new();
2727
2728 for variant in variants {
2729 let action_attrs = parse_action_attrs(&variant.attrs);
2730 if let Some(tag) = action_attrs.tag {
2731 tag_names.push(tag);
2732 }
2733 }
2734
2735 if tag_names.is_empty() {
2736 return String::new();
2737 }
2738
2739 let tags_pattern = tag_names.join("|");
2740 format!(
2743 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2744 tags_pattern, tags_pattern, tags_pattern
2745 )
2746}
2747
2748fn generate_multi_tag_output(
2750 input: &DeriveInput,
2751 enum_name: &syn::Ident,
2752 enum_data: &syn::DataEnum,
2753 prompt_template: String,
2754 actions_doc: String,
2755) -> TokenStream {
2756 let found_crate =
2757 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2758 let crate_path = match found_crate {
2759 FoundCrate::Itself => {
2760 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2762 quote!(::#ident)
2763 }
2764 FoundCrate::Name(name) => {
2765 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2766 quote!(::#ident)
2767 }
2768 };
2769
2770 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2772 let user_variables: Vec<String> = placeholders
2773 .iter()
2774 .filter_map(|(name, _)| {
2775 if name != "actions_doc" {
2776 Some(name.clone())
2777 } else {
2778 None
2779 }
2780 })
2781 .collect();
2782
2783 let enum_name_str = enum_name.to_string();
2785 let snake_case_name = to_snake_case(&enum_name_str);
2786 let function_name = syn::Ident::new(
2787 &format!("build_{}_prompt", snake_case_name),
2788 proc_macro2::Span::call_site(),
2789 );
2790
2791 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2793 .iter()
2794 .map(|var| {
2795 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2796 quote! { #ident: &str }
2797 })
2798 .collect();
2799
2800 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2802 .iter()
2803 .map(|var| {
2804 let var_str = var.clone();
2805 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2806 quote! {
2807 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2808 }
2809 })
2810 .collect();
2811
2812 let extractor_name = syn::Ident::new(
2814 &format!("{}Extractor", enum_name),
2815 proc_macro2::Span::call_site(),
2816 );
2817
2818 let filtered_attrs: Vec<_> = input
2820 .attrs
2821 .iter()
2822 .filter(|attr| !attr.path().is_ident("intent"))
2823 .collect();
2824
2825 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2827 .variants
2828 .iter()
2829 .map(|variant| {
2830 let variant_name = &variant.ident;
2831 let variant_attrs: Vec<_> = variant
2832 .attrs
2833 .iter()
2834 .filter(|attr| !attr.path().is_ident("action"))
2835 .collect();
2836 let fields = &variant.fields;
2837
2838 let filtered_fields = match fields {
2840 syn::Fields::Named(named_fields) => {
2841 let filtered: Vec<_> = named_fields
2842 .named
2843 .iter()
2844 .map(|field| {
2845 let field_name = &field.ident;
2846 let field_type = &field.ty;
2847 let field_vis = &field.vis;
2848 let filtered_attrs: Vec<_> = field
2849 .attrs
2850 .iter()
2851 .filter(|attr| !attr.path().is_ident("action"))
2852 .collect();
2853 quote! {
2854 #(#filtered_attrs)*
2855 #field_vis #field_name: #field_type
2856 }
2857 })
2858 .collect();
2859 quote! { { #(#filtered,)* } }
2860 }
2861 syn::Fields::Unnamed(unnamed_fields) => {
2862 let types: Vec<_> = unnamed_fields
2863 .unnamed
2864 .iter()
2865 .map(|field| {
2866 let field_type = &field.ty;
2867 quote! { #field_type }
2868 })
2869 .collect();
2870 quote! { (#(#types),*) }
2871 }
2872 syn::Fields::Unit => quote! {},
2873 };
2874
2875 quote! {
2876 #(#variant_attrs)*
2877 #variant_name #filtered_fields
2878 }
2879 })
2880 .collect();
2881
2882 let vis = &input.vis;
2883 let generics = &input.generics;
2884
2885 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2887
2888 let tags_regex = generate_tags_regex(&enum_data.variants);
2890
2891 let expanded = quote! {
2892 #(#filtered_attrs)*
2894 #vis enum #enum_name #generics {
2895 #(#filtered_variants),*
2896 }
2897
2898 pub fn #function_name(#(#function_params),*) -> String {
2900 let mut env = minijinja::Environment::new();
2901 env.add_template("prompt", #prompt_template)
2902 .expect("Failed to parse intent prompt template");
2903
2904 let tmpl = env.get_template("prompt").unwrap();
2905
2906 let mut __template_context = std::collections::HashMap::new();
2907
2908 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2910
2911 #(#context_insertions)*
2913
2914 tmpl.render(&__template_context)
2915 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2916 }
2917
2918 pub struct #extractor_name;
2920
2921 impl #extractor_name {
2922 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2923 use ::quick_xml::events::Event;
2924 use ::quick_xml::Reader;
2925
2926 let mut actions = Vec::new();
2927 let mut reader = Reader::from_str(text);
2928 reader.config_mut().trim_text(true);
2929
2930 let mut buf = Vec::new();
2931
2932 loop {
2933 match reader.read_event_into(&mut buf) {
2934 Ok(Event::Start(e)) => {
2935 let owned_e = e.into_owned();
2936 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2937 let is_empty = false;
2938
2939 #parsing_arms
2940 }
2941 Ok(Event::Empty(e)) => {
2942 let owned_e = e.into_owned();
2943 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2944 let is_empty = true;
2945
2946 #parsing_arms
2947 }
2948 Ok(Event::Eof) => break,
2949 Err(_) => {
2950 break;
2952 }
2953 _ => {}
2954 }
2955 buf.clear();
2956 }
2957
2958 actions.into_iter().next()
2959 }
2960
2961 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2962 use ::quick_xml::events::Event;
2963 use ::quick_xml::Reader;
2964
2965 let mut actions = Vec::new();
2966 let mut reader = Reader::from_str(text);
2967 reader.config_mut().trim_text(true);
2968
2969 let mut buf = Vec::new();
2970
2971 loop {
2972 match reader.read_event_into(&mut buf) {
2973 Ok(Event::Start(e)) => {
2974 let owned_e = e.into_owned();
2975 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2976 let is_empty = false;
2977
2978 #parsing_arms
2979 }
2980 Ok(Event::Empty(e)) => {
2981 let owned_e = e.into_owned();
2982 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2983 let is_empty = true;
2984
2985 #parsing_arms
2986 }
2987 Ok(Event::Eof) => break,
2988 Err(_) => {
2989 break;
2991 }
2992 _ => {}
2993 }
2994 buf.clear();
2995 }
2996
2997 Ok(actions)
2998 }
2999
3000 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
3001 where
3002 F: FnMut(#enum_name) -> String,
3003 {
3004 use ::regex::Regex;
3005
3006 let regex_pattern = #tags_regex;
3007 if regex_pattern.is_empty() {
3008 return text.to_string();
3009 }
3010
3011 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
3012 panic!("Failed to compile regex for action tags: {}", e);
3013 });
3014
3015 re.replace_all(text, |caps: &::regex::Captures| {
3016 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
3017
3018 if let Some(action) = self.parse_single_action(matched) {
3020 transformer(action)
3021 } else {
3022 matched.to_string()
3024 }
3025 }).to_string()
3026 }
3027
3028 pub fn strip_actions(&self, text: &str) -> String {
3029 self.transform_actions(text, |_| String::new())
3030 }
3031 }
3032 };
3033
3034 TokenStream::from(expanded)
3035}
3036
3037fn generate_parsing_arms(
3039 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3040 enum_name: &syn::Ident,
3041) -> proc_macro2::TokenStream {
3042 let mut arms = Vec::new();
3043
3044 for variant in variants {
3045 let variant_name = &variant.ident;
3046 let action_attrs = parse_action_attrs(&variant.attrs);
3047
3048 if let Some(tag) = action_attrs.tag {
3049 match &variant.fields {
3050 syn::Fields::Unit => {
3051 arms.push(quote! {
3053 if &tag_name == #tag {
3054 actions.push(#enum_name::#variant_name);
3055 }
3056 });
3057 }
3058 syn::Fields::Unnamed(_fields) => {
3059 arms.push(quote! {
3061 if &tag_name == #tag && !is_empty {
3062 match reader.read_text(owned_e.name()) {
3064 Ok(text) => {
3065 actions.push(#enum_name::#variant_name(text.to_string()));
3066 }
3067 Err(_) => {
3068 actions.push(#enum_name::#variant_name(String::new()));
3070 }
3071 }
3072 }
3073 });
3074 }
3075 syn::Fields::Named(fields) => {
3076 let mut field_names = Vec::new();
3078 let mut has_inner_text_field = None;
3079
3080 for field in &fields.named {
3081 let field_name = field.ident.as_ref().unwrap();
3082 let field_attrs = parse_field_action_attrs(&field.attrs);
3083
3084 if field_attrs.is_attribute {
3085 field_names.push(field_name.clone());
3086 } else if field_attrs.is_inner_text {
3087 has_inner_text_field = Some(field_name.clone());
3088 }
3089 }
3090
3091 if let Some(inner_text_field) = has_inner_text_field {
3092 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3095 quote! {
3096 let mut #field_name = String::new();
3097 for attr in owned_e.attributes() {
3098 if let Ok(attr) = attr {
3099 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3100 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3101 break;
3102 }
3103 }
3104 }
3105 }
3106 }).collect();
3107
3108 arms.push(quote! {
3109 if &tag_name == #tag {
3110 #(#attr_extractions)*
3111
3112 if is_empty {
3114 let #inner_text_field = String::new();
3115 actions.push(#enum_name::#variant_name {
3116 #(#field_names,)*
3117 #inner_text_field,
3118 });
3119 } else {
3120 match reader.read_text(owned_e.name()) {
3122 Ok(text) => {
3123 let #inner_text_field = text.to_string();
3124 actions.push(#enum_name::#variant_name {
3125 #(#field_names,)*
3126 #inner_text_field,
3127 });
3128 }
3129 Err(_) => {
3130 let #inner_text_field = String::new();
3132 actions.push(#enum_name::#variant_name {
3133 #(#field_names,)*
3134 #inner_text_field,
3135 });
3136 }
3137 }
3138 }
3139 }
3140 });
3141 } else {
3142 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3144 quote! {
3145 let mut #field_name = String::new();
3146 for attr in owned_e.attributes() {
3147 if let Ok(attr) = attr {
3148 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3149 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3150 break;
3151 }
3152 }
3153 }
3154 }
3155 }).collect();
3156
3157 arms.push(quote! {
3158 if &tag_name == #tag {
3159 #(#attr_extractions)*
3160 actions.push(#enum_name::#variant_name {
3161 #(#field_names),*
3162 });
3163 }
3164 });
3165 }
3166 }
3167 }
3168 }
3169 }
3170
3171 quote! {
3172 #(#arms)*
3173 }
3174}
3175
3176#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3178pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3179 let input = parse_macro_input!(input as DeriveInput);
3180
3181 let found_crate =
3182 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3183 let crate_path = match found_crate {
3184 FoundCrate::Itself => {
3185 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3187 quote!(::#ident)
3188 }
3189 FoundCrate::Name(name) => {
3190 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3191 quote!(::#ident)
3192 }
3193 };
3194
3195 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
3197
3198 let struct_name = &input.ident;
3199 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3200
3201 let placeholders = parse_template_placeholders_with_mode(&template);
3203
3204 let mut converted_template = template.clone();
3206 let mut context_fields = Vec::new();
3207
3208 let fields = match &input.data {
3210 Data::Struct(data_struct) => match &data_struct.fields {
3211 syn::Fields::Named(fields) => &fields.named,
3212 _ => panic!("ToPromptFor is only supported for structs with named fields"),
3213 },
3214 _ => panic!("ToPromptFor is only supported for structs"),
3215 };
3216
3217 let has_mode_support = input.attrs.iter().any(|attr| {
3219 if attr.path().is_ident("prompt")
3220 && let Ok(metas) =
3221 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3222 {
3223 for meta in metas {
3224 if let Meta::NameValue(nv) = meta
3225 && nv.path.is_ident("mode")
3226 {
3227 return true;
3228 }
3229 }
3230 }
3231 false
3232 });
3233
3234 for (placeholder_name, mode_opt) in &placeholders {
3236 if placeholder_name == "self" {
3237 if let Some(specific_mode) = mode_opt {
3238 let unique_key = format!("self__{}", specific_mode);
3240
3241 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3243 let replacement = format!("{{{{ {} }}}}", unique_key);
3244 converted_template = converted_template.replace(&pattern, &replacement);
3245
3246 context_fields.push(quote! {
3248 context.insert(
3249 #unique_key.to_string(),
3250 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3251 );
3252 });
3253 } else {
3254 if has_mode_support {
3257 context_fields.push(quote! {
3259 context.insert(
3260 "self".to_string(),
3261 minijinja::Value::from(self.to_prompt_with_mode(mode))
3262 );
3263 });
3264 } else {
3265 context_fields.push(quote! {
3267 context.insert(
3268 "self".to_string(),
3269 minijinja::Value::from(self.to_prompt())
3270 );
3271 });
3272 }
3273 }
3274 } else {
3275 let field_exists = fields.iter().any(|f| {
3278 f.ident
3279 .as_ref()
3280 .is_some_and(|ident| ident == placeholder_name)
3281 });
3282
3283 if field_exists {
3284 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3285
3286 context_fields.push(quote! {
3290 context.insert(
3291 #placeholder_name.to_string(),
3292 minijinja::Value::from_serialize(&self.#field_ident)
3293 );
3294 });
3295 }
3296 }
3298 }
3299
3300 let expanded = quote! {
3301 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3302 where
3303 #target_type: serde::Serialize,
3304 {
3305 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3306 let mut env = minijinja::Environment::new();
3308 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3309 panic!("Failed to parse template: {}", e)
3310 });
3311
3312 let tmpl = env.get_template("prompt").unwrap();
3313
3314 let mut context = std::collections::HashMap::new();
3316 context.insert(
3318 "self".to_string(),
3319 minijinja::Value::from_serialize(self)
3320 );
3321 context.insert(
3323 "target".to_string(),
3324 minijinja::Value::from_serialize(target)
3325 );
3326 #(#context_fields)*
3327
3328 tmpl.render(context).unwrap_or_else(|e| {
3330 format!("Failed to render prompt: {}", e)
3331 })
3332 }
3333 }
3334 };
3335
3336 TokenStream::from(expanded)
3337}
3338
3339struct AgentAttrs {
3345 expertise: Option<String>,
3346 output: Option<syn::Type>,
3347 backend: Option<String>,
3348 model: Option<String>,
3349 inner: Option<String>,
3350 default_inner: Option<String>,
3351 max_retries: Option<u32>,
3352 profile: Option<String>,
3353}
3354
3355impl Parse for AgentAttrs {
3356 fn parse(input: ParseStream) -> syn::Result<Self> {
3357 let mut expertise = None;
3358 let mut output = None;
3359 let mut backend = None;
3360 let mut model = None;
3361 let mut inner = None;
3362 let mut default_inner = None;
3363 let mut max_retries = None;
3364 let mut profile = None;
3365
3366 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3367
3368 for meta in pairs {
3369 match meta {
3370 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3371 if let syn::Expr::Lit(syn::ExprLit {
3372 lit: syn::Lit::Str(lit_str),
3373 ..
3374 }) = &nv.value
3375 {
3376 expertise = Some(lit_str.value());
3377 }
3378 }
3379 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3380 if let syn::Expr::Lit(syn::ExprLit {
3381 lit: syn::Lit::Str(lit_str),
3382 ..
3383 }) = &nv.value
3384 {
3385 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3386 output = Some(ty);
3387 }
3388 }
3389 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3390 if let syn::Expr::Lit(syn::ExprLit {
3391 lit: syn::Lit::Str(lit_str),
3392 ..
3393 }) = &nv.value
3394 {
3395 backend = Some(lit_str.value());
3396 }
3397 }
3398 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3399 if let syn::Expr::Lit(syn::ExprLit {
3400 lit: syn::Lit::Str(lit_str),
3401 ..
3402 }) = &nv.value
3403 {
3404 model = Some(lit_str.value());
3405 }
3406 }
3407 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3408 if let syn::Expr::Lit(syn::ExprLit {
3409 lit: syn::Lit::Str(lit_str),
3410 ..
3411 }) = &nv.value
3412 {
3413 inner = Some(lit_str.value());
3414 }
3415 }
3416 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3417 if let syn::Expr::Lit(syn::ExprLit {
3418 lit: syn::Lit::Str(lit_str),
3419 ..
3420 }) = &nv.value
3421 {
3422 default_inner = Some(lit_str.value());
3423 }
3424 }
3425 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3426 if let syn::Expr::Lit(syn::ExprLit {
3427 lit: syn::Lit::Int(lit_int),
3428 ..
3429 }) = &nv.value
3430 {
3431 max_retries = Some(lit_int.base10_parse()?);
3432 }
3433 }
3434 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3435 if let syn::Expr::Lit(syn::ExprLit {
3436 lit: syn::Lit::Str(lit_str),
3437 ..
3438 }) = &nv.value
3439 {
3440 profile = Some(lit_str.value());
3441 }
3442 }
3443 _ => {}
3444 }
3445 }
3446
3447 Ok(AgentAttrs {
3448 expertise,
3449 output,
3450 backend,
3451 model,
3452 inner,
3453 default_inner,
3454 max_retries,
3455 profile,
3456 })
3457 }
3458}
3459
3460fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3462 for attr in attrs {
3463 if attr.path().is_ident("agent") {
3464 return attr.parse_args::<AgentAttrs>();
3465 }
3466 }
3467
3468 Ok(AgentAttrs {
3469 expertise: None,
3470 output: None,
3471 backend: None,
3472 model: None,
3473 inner: None,
3474 default_inner: None,
3475 max_retries: None,
3476 profile: None,
3477 })
3478}
3479
3480fn generate_backend_constructors(
3482 struct_name: &syn::Ident,
3483 backend: &str,
3484 _model: Option<&str>,
3485 _profile: Option<&str>,
3486 crate_path: &proc_macro2::TokenStream,
3487) -> proc_macro2::TokenStream {
3488 match backend {
3489 "claude" => {
3490 quote! {
3491 impl #struct_name {
3492 pub fn with_claude() -> Self {
3494 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3495 }
3496
3497 pub fn with_claude_model(model: &str) -> Self {
3499 Self::new(
3500 #crate_path::agent::impls::ClaudeCodeAgent::new()
3501 .with_model_str(model)
3502 )
3503 }
3504 }
3505 }
3506 }
3507 "gemini" => {
3508 quote! {
3509 impl #struct_name {
3510 pub fn with_gemini() -> Self {
3512 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3513 }
3514
3515 pub fn with_gemini_model(model: &str) -> Self {
3517 Self::new(
3518 #crate_path::agent::impls::GeminiAgent::new()
3519 .with_model_str(model)
3520 )
3521 }
3522 }
3523 }
3524 }
3525 _ => quote! {},
3526 }
3527}
3528
3529fn generate_default_impl(
3531 struct_name: &syn::Ident,
3532 backend: &str,
3533 model: Option<&str>,
3534 profile: Option<&str>,
3535 crate_path: &proc_macro2::TokenStream,
3536) -> proc_macro2::TokenStream {
3537 let profile_expr = if let Some(profile_str) = profile {
3539 match profile_str.to_lowercase().as_str() {
3540 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3541 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3542 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3543 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3545 } else {
3546 quote! { #crate_path::agent::ExecutionProfile::default() }
3547 };
3548
3549 let agent_init = match backend {
3550 "gemini" => {
3551 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3552
3553 if let Some(model_str) = model {
3554 builder = quote! { #builder.with_model_str(#model_str) };
3555 }
3556
3557 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3558 builder
3559 }
3560 _ => {
3561 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3563
3564 if let Some(model_str) = model {
3565 builder = quote! { #builder.with_model_str(#model_str) };
3566 }
3567
3568 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3569 builder
3570 }
3571 };
3572
3573 quote! {
3574 impl Default for #struct_name {
3575 fn default() -> Self {
3576 Self::new(#agent_init)
3577 }
3578 }
3579 }
3580}
3581
3582#[proc_macro_derive(Agent, attributes(agent))]
3591pub fn derive_agent(input: TokenStream) -> TokenStream {
3592 let input = parse_macro_input!(input as DeriveInput);
3593 let struct_name = &input.ident;
3594
3595 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3597 Ok(attrs) => attrs,
3598 Err(e) => return e.to_compile_error().into(),
3599 };
3600
3601 let expertise = agent_attrs
3602 .expertise
3603 .unwrap_or_else(|| String::from("general AI assistant"));
3604 let output_type = agent_attrs
3605 .output
3606 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3607 let backend = agent_attrs
3608 .backend
3609 .unwrap_or_else(|| String::from("claude"));
3610 let model = agent_attrs.model;
3611 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3616 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3617 let crate_path = match found_crate {
3618 FoundCrate::Itself => {
3619 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3621 quote!(::#ident)
3622 }
3623 FoundCrate::Name(name) => {
3624 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3625 quote!(::#ident)
3626 }
3627 };
3628
3629 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3630
3631 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3633 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3634
3635 let enhanced_expertise = if is_string_output {
3637 quote! { #expertise }
3639 } else {
3640 let type_name = quote!(#output_type).to_string();
3642 quote! {
3643 {
3644 use std::sync::OnceLock;
3645 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3646
3647 EXPERTISE_CACHE.get_or_init(|| {
3648 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3650
3651 if schema.is_empty() {
3652 format!(
3654 concat!(
3655 #expertise,
3656 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3657 "Do not include any text outside the JSON object."
3658 ),
3659 #type_name
3660 )
3661 } else {
3662 format!(
3664 concat!(
3665 #expertise,
3666 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3667 ),
3668 schema
3669 )
3670 }
3671 }).as_str()
3672 }
3673 }
3674 };
3675
3676 let agent_init = match backend.as_str() {
3678 "gemini" => {
3679 if let Some(model_str) = model {
3680 quote! {
3681 use #crate_path::agent::impls::GeminiAgent;
3682 let agent = GeminiAgent::new().with_model_str(#model_str);
3683 }
3684 } else {
3685 quote! {
3686 use #crate_path::agent::impls::GeminiAgent;
3687 let agent = GeminiAgent::new();
3688 }
3689 }
3690 }
3691 "claude" => {
3692 if let Some(model_str) = model {
3693 quote! {
3694 use #crate_path::agent::impls::ClaudeCodeAgent;
3695 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3696 }
3697 } else {
3698 quote! {
3699 use #crate_path::agent::impls::ClaudeCodeAgent;
3700 let agent = ClaudeCodeAgent::new();
3701 }
3702 }
3703 }
3704 _ => {
3705 if let Some(model_str) = model {
3707 quote! {
3708 use #crate_path::agent::impls::ClaudeCodeAgent;
3709 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3710 }
3711 } else {
3712 quote! {
3713 use #crate_path::agent::impls::ClaudeCodeAgent;
3714 let agent = ClaudeCodeAgent::new();
3715 }
3716 }
3717 }
3718 };
3719
3720 let expanded = quote! {
3721 #[async_trait::async_trait]
3722 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3723 type Output = #output_type;
3724
3725 fn expertise(&self) -> &str {
3726 #enhanced_expertise
3727 }
3728
3729 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3730 #agent_init
3732
3733 let agent_ref = &agent;
3735 #crate_path::agent::retry::retry_execution(
3736 #max_retries,
3737 &intent,
3738 move |payload| {
3739 let payload = payload.clone();
3740 async move {
3741 let response = agent_ref.execute(payload).await?;
3743
3744 let json_str = #crate_path::extract_json(&response)
3746 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3747 message: format!("Failed to extract JSON: {}", e),
3748 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3749 })?;
3750
3751 serde_json::from_str::<Self::Output>(&json_str)
3753 .map_err(|e| {
3754 let reason = if e.is_eof() {
3756 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3757 } else if e.is_syntax() {
3758 #crate_path::agent::error::ParseErrorReason::InvalidJson
3759 } else {
3760 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3761 };
3762
3763 #crate_path::agent::AgentError::ParseError {
3764 message: format!("Failed to parse JSON: {}", e),
3765 reason,
3766 }
3767 })
3768 }
3769 }
3770 ).await
3771 }
3772
3773 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3774 #agent_init
3776 agent.is_available().await
3777 }
3778 }
3779 };
3780
3781 TokenStream::from(expanded)
3782}
3783
3784#[proc_macro_attribute]
3799pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3800 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3802 Ok(attrs) => attrs,
3803 Err(e) => return e.to_compile_error().into(),
3804 };
3805
3806 let input = parse_macro_input!(item as DeriveInput);
3808 let struct_name = &input.ident;
3809 let vis = &input.vis;
3810
3811 let expertise = agent_attrs
3812 .expertise
3813 .unwrap_or_else(|| String::from("general AI assistant"));
3814 let output_type = agent_attrs
3815 .output
3816 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3817 let backend = agent_attrs
3818 .backend
3819 .unwrap_or_else(|| String::from("claude"));
3820 let model = agent_attrs.model;
3821 let profile = agent_attrs.profile;
3822
3823 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3825 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3826
3827 let found_crate =
3829 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3830 let crate_path = match found_crate {
3831 FoundCrate::Itself => {
3832 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3833 quote!(::#ident)
3834 }
3835 FoundCrate::Name(name) => {
3836 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3837 quote!(::#ident)
3838 }
3839 };
3840
3841 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3843 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3844
3845 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3847 let type_path: syn::Type =
3849 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3850 quote! { #type_path }
3851 } else {
3852 match backend.as_str() {
3854 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3855 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3856 }
3857 };
3858
3859 let struct_def = quote! {
3861 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3862 inner: #inner_generic_ident,
3863 }
3864 };
3865
3866 let constructors = quote! {
3868 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3869 pub fn new(inner: #inner_generic_ident) -> Self {
3871 Self { inner }
3872 }
3873 }
3874 };
3875
3876 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3878 let default_impl = quote! {
3880 impl Default for #struct_name {
3881 fn default() -> Self {
3882 Self {
3883 inner: <#default_agent_type as Default>::default(),
3884 }
3885 }
3886 }
3887 };
3888 (quote! {}, default_impl)
3889 } else {
3890 let backend_constructors = generate_backend_constructors(
3892 struct_name,
3893 &backend,
3894 model.as_deref(),
3895 profile.as_deref(),
3896 &crate_path,
3897 );
3898 let default_impl = generate_default_impl(
3899 struct_name,
3900 &backend,
3901 model.as_deref(),
3902 profile.as_deref(),
3903 &crate_path,
3904 );
3905 (backend_constructors, default_impl)
3906 };
3907
3908 let enhanced_expertise = if is_string_output {
3910 quote! { #expertise }
3912 } else {
3913 let type_name = quote!(#output_type).to_string();
3915 quote! {
3916 {
3917 use std::sync::OnceLock;
3918 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3919
3920 EXPERTISE_CACHE.get_or_init(|| {
3921 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3923
3924 if schema.is_empty() {
3925 format!(
3927 concat!(
3928 #expertise,
3929 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3930 "Do not include any text outside the JSON object."
3931 ),
3932 #type_name
3933 )
3934 } else {
3935 format!(
3937 concat!(
3938 #expertise,
3939 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3940 ),
3941 schema
3942 )
3943 }
3944 }).as_str()
3945 }
3946 }
3947 };
3948
3949 let agent_impl = quote! {
3951 #[async_trait::async_trait]
3952 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3953 where
3954 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3955 {
3956 type Output = #output_type;
3957
3958 fn expertise(&self) -> &str {
3959 #enhanced_expertise
3960 }
3961
3962 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3963 let enhanced_payload = intent.prepend_text(self.expertise());
3965
3966 let response = self.inner.execute(enhanced_payload).await?;
3968
3969 let json_str = #crate_path::extract_json(&response)
3971 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3972 message: e.to_string(),
3973 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3974 })?;
3975
3976 serde_json::from_str(&json_str).map_err(|e| {
3978 let reason = if e.is_eof() {
3979 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3980 } else if e.is_syntax() {
3981 #crate_path::agent::error::ParseErrorReason::InvalidJson
3982 } else {
3983 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3984 };
3985 #crate_path::agent::AgentError::ParseError {
3986 message: e.to_string(),
3987 reason,
3988 }
3989 })
3990 }
3991
3992 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3993 self.inner.is_available().await
3994 }
3995 }
3996 };
3997
3998 let expanded = quote! {
3999 #struct_def
4000 #constructors
4001 #backend_constructors
4002 #default_impl
4003 #agent_impl
4004 };
4005
4006 TokenStream::from(expanded)
4007}
4008
4009#[proc_macro_derive(TypeMarker)]
4031pub fn derive_type_marker(input: TokenStream) -> TokenStream {
4032 let input = parse_macro_input!(input as DeriveInput);
4033 let struct_name = &input.ident;
4034 let type_name_str = struct_name.to_string();
4035
4036 let found_crate =
4038 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4039 let crate_path = match found_crate {
4040 FoundCrate::Itself => {
4041 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4042 quote!(::#ident)
4043 }
4044 FoundCrate::Name(name) => {
4045 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4046 quote!(::#ident)
4047 }
4048 };
4049
4050 let expanded = quote! {
4051 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4052 const TYPE_NAME: &'static str = #type_name_str;
4053 }
4054 };
4055
4056 TokenStream::from(expanded)
4057}
4058
4059#[proc_macro_attribute]
4095pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
4096 let input = parse_macro_input!(item as syn::DeriveInput);
4097 let struct_name = &input.ident;
4098 let vis = &input.vis;
4099 let type_name_str = struct_name.to_string();
4100
4101 let default_fn_name = syn::Ident::new(
4103 &format!("default_{}_type", to_snake_case(&type_name_str)),
4104 struct_name.span(),
4105 );
4106
4107 let found_crate =
4109 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4110 let crate_path = match found_crate {
4111 FoundCrate::Itself => {
4112 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4113 quote!(::#ident)
4114 }
4115 FoundCrate::Name(name) => {
4116 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4117 quote!(::#ident)
4118 }
4119 };
4120
4121 let fields = match &input.data {
4123 syn::Data::Struct(data_struct) => match &data_struct.fields {
4124 syn::Fields::Named(fields) => &fields.named,
4125 _ => {
4126 return syn::Error::new_spanned(
4127 struct_name,
4128 "type_marker only works with structs with named fields",
4129 )
4130 .to_compile_error()
4131 .into();
4132 }
4133 },
4134 _ => {
4135 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
4136 .to_compile_error()
4137 .into();
4138 }
4139 };
4140
4141 let mut new_fields = vec![];
4143
4144 let default_fn_name_str = default_fn_name.to_string();
4146 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
4147
4148 new_fields.push(quote! {
4153 #[serde(default = #default_fn_name_lit)]
4154 __type: String
4155 });
4156
4157 for field in fields {
4159 new_fields.push(quote! { #field });
4160 }
4161
4162 let attrs = &input.attrs;
4164 let generics = &input.generics;
4165
4166 let expanded = quote! {
4167 fn #default_fn_name() -> String {
4169 #type_name_str.to_string()
4170 }
4171
4172 #(#attrs)*
4174 #vis struct #struct_name #generics {
4175 #(#new_fields),*
4176 }
4177
4178 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4180 const TYPE_NAME: &'static str = #type_name_str;
4181 }
4182 };
4183
4184 TokenStream::from(expanded)
4185}