1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
169
170 if is_vec {
171 let comment = if !field_docs.is_empty() {
174 format!(" // {}", field_docs)
175 } else {
176 String::new()
177 };
178
179 field_schema_parts.push(quote! {
180 {
181 let type_name = stringify!(#inner_type);
182 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
183 }
184 });
185
186 if let Some(inner) = inner_type
188 && !is_primitive_type(inner)
189 {
190 nested_type_collectors.push(quote! {
191 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
192 });
193 }
194 } else {
195 let field_type = &field.ty;
197 let is_primitive = is_primitive_type(field_type);
198
199 if !is_primitive {
200 let comment = if !field_docs.is_empty() {
203 format!(" // {}", field_docs)
204 } else {
205 String::new()
206 };
207
208 field_schema_parts.push(quote! {
209 {
210 let type_name = stringify!(#field_type);
211 format!(" {}: {};{}", #field_name_str, type_name, #comment)
212 }
213 });
214
215 nested_type_collectors.push(quote! {
217 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
218 });
219 } else {
220 let type_str = format_type_for_schema(&field.ty);
223 let comment = if !field_docs.is_empty() {
224 format!(" // {}", field_docs)
225 } else {
226 String::new()
227 };
228
229 field_schema_parts.push(quote! {
230 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
231 });
232 }
233 }
234 }
235
236 let mut header_lines = Vec::new();
251
252 if !struct_docs.is_empty() {
254 header_lines.push("/**".to_string());
255 header_lines.push(format!(" * {}", struct_docs));
256 header_lines.push(" */".to_string());
257 }
258
259 header_lines.push(format!("type {} = {{", struct_name));
261
262 quote! {
263 {
264 let mut all_lines: Vec<String> = Vec::new();
265
266 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
268 let mut seen_types = std::collections::HashSet::<String>::new();
269
270 for schema in nested_schemas {
271 if !schema.is_empty() {
272 if seen_types.insert(schema.clone()) {
274 all_lines.push(schema);
275 all_lines.push(String::new()); }
277 }
278 }
279
280 let mut lines: Vec<String> = Vec::new();
282 #(lines.push(#header_lines.to_string());)*
283 #(lines.push(#field_schema_parts);)*
284 lines.push("}".to_string());
285 all_lines.push(lines.join("\n"));
286
287 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
288 }
289 }
290}
291
292fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
294 if let syn::Type::Path(type_path) = ty
295 && let Some(last_segment) = type_path.path.segments.last()
296 && last_segment.ident == "Vec"
297 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
298 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
299 {
300 return (true, Some(inner_type));
301 }
302 (false, None)
303}
304
305fn is_primitive_type(ty: &syn::Type) -> bool {
307 if let syn::Type::Path(type_path) = ty
308 && let Some(last_segment) = type_path.path.segments.last()
309 {
310 let type_name = last_segment.ident.to_string();
311 matches!(
312 type_name.as_str(),
313 "String"
314 | "str"
315 | "i8"
316 | "i16"
317 | "i32"
318 | "i64"
319 | "i128"
320 | "isize"
321 | "u8"
322 | "u16"
323 | "u32"
324 | "u64"
325 | "u128"
326 | "usize"
327 | "f32"
328 | "f64"
329 | "bool"
330 | "Vec"
331 | "Option"
332 | "HashMap"
333 | "BTreeMap"
334 | "HashSet"
335 | "BTreeSet"
336 )
337 } else {
338 true
340 }
341}
342
343fn format_type_for_schema(ty: &syn::Type) -> String {
345 match ty {
347 syn::Type::Path(type_path) => {
348 let path = &type_path.path;
349 if let Some(last_segment) = path.segments.last() {
350 let type_name = last_segment.ident.to_string();
351
352 if type_name == "Option"
354 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
355 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
356 {
357 return format!("{} | null", format_type_for_schema(inner_type));
358 }
359
360 match type_name.as_str() {
362 "String" | "str" => "string".to_string(),
363 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
364 | "u64" | "u128" | "usize" => "number".to_string(),
365 "f32" | "f64" => "number".to_string(),
366 "bool" => "boolean".to_string(),
367 "Vec" => {
368 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
369 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
370 {
371 return format!("{}[]", format_type_for_schema(inner_type));
372 }
373 "array".to_string()
374 }
375 _ => type_name.to_lowercase(),
376 }
377 } else {
378 "unknown".to_string()
379 }
380 }
381 _ => "unknown".to_string(),
382 }
383}
384
385#[derive(Default)]
387struct PromptAttributes {
388 skip: bool,
389 rename: Option<String>,
390 description: Option<String>,
391}
392
393fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
396 let mut result = PromptAttributes::default();
397
398 for attr in attrs {
399 if attr.path().is_ident("prompt") {
400 if let Ok(meta_list) = attr.meta.require_list() {
402 if let Ok(metas) =
404 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
405 {
406 for meta in metas {
407 if let Meta::NameValue(nv) = meta {
408 if nv.path.is_ident("rename") {
409 if let syn::Expr::Lit(syn::ExprLit {
410 lit: syn::Lit::Str(lit_str),
411 ..
412 }) = nv.value
413 {
414 result.rename = Some(lit_str.value());
415 }
416 } else if nv.path.is_ident("description")
417 && let syn::Expr::Lit(syn::ExprLit {
418 lit: syn::Lit::Str(lit_str),
419 ..
420 }) = nv.value
421 {
422 result.description = Some(lit_str.value());
423 }
424 } else if let Meta::Path(path) = meta
425 && path.is_ident("skip")
426 {
427 result.skip = true;
428 }
429 }
430 }
431
432 let tokens_str = meta_list.tokens.to_string();
434 if tokens_str == "skip" {
435 result.skip = true;
436 }
437 }
438
439 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
441 result.description = Some(lit_str.value());
442 }
443 }
444 }
445 result
446}
447
/// Produces a sample JSON literal for a schema type string such as `format_type_for_schema`
/// returns.
///
/// The mapping is:
/// * `string`, `number` and `boolean` give `"example"`, `0` and `false`;
/// * any type ending in `[]` gives `[]`;
/// * a union such as `number | null` is exemplified by its first member;
/// * anything else gives `null`.
///
/// The array check runs before the union check, so `a | b[]` yields `[]`.
fn generate_example_value_for_type(type_str: &str) -> String {
    let literal = match type_str {
        "string" => "\"example\"",
        "number" => "0",
        "boolean" => "false",
        array if array.ends_with("[]") => "[]",
        union if union.contains('|') => {
            let head = union.split_once('|').map_or(union, |(first, _)| first);
            return generate_example_value_for_type(head.trim());
        }
        _ => "null",
    };
    literal.to_owned()
}
463
464fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
466 for attr in attrs {
467 if attr.path().is_ident("serde")
468 && let Ok(meta_list) = attr.meta.require_list()
469 && let Ok(metas) =
470 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
471 {
472 for meta in metas {
473 if let Meta::NameValue(nv) = meta
474 && nv.path.is_ident("rename")
475 && let syn::Expr::Lit(syn::ExprLit {
476 lit: syn::Lit::Str(lit_str),
477 ..
478 }) = nv.value
479 {
480 return Some(lit_str.value());
481 }
482 }
483 }
484 }
485 None
486}
487
/// A serde `rename_all` case convention.
///
/// The conversions reproduce serde_derive's rules for enum variants, so the names shown in
/// a prompt are exactly the strings serde accepts when deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leaves names untouched; never returned by `from_str`.
    #[allow(dead_code)]
    None,
    /// `"lowercase"`
    LowerCase,
    /// `"UPPERCASE"`
    UpperCase,
    /// `"PascalCase"`
    PascalCase,
    /// `"camelCase"`
    CamelCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses the string value of `#[serde(rename_all = "...")]`.
    ///
    /// Matching is case-sensitive. Anything other than serde's eight spellings yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies this convention to a Rust variant name (expected to be `PascalCase`).
    ///
    /// Word boundaries go before every uppercase character except the first, and only
    /// ASCII letters change case, as in serde_derive. The screaming and kebab forms are
    /// derived from the snake_case form. As a result, an underscore already present in
    /// the name becomes `-` in the kebab conventions (`Foo_Bar` becomes `foo--bar`),
    /// matching serde.
    fn apply(&self, name: &str) -> String {
        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_ascii_lowercase(),
            Self::UpperCase => name.to_ascii_uppercase(),
            Self::CamelCase => {
                let mut chars = name.chars();
                match chars.next() {
                    None => String::new(),
                    Some(first) => {
                        let mut out = String::with_capacity(name.len());
                        out.push(first.to_ascii_lowercase());
                        out.push_str(chars.as_str());
                        out
                    }
                }
            }
            Self::SnakeCase => Self::snake_case(name),
            Self::ScreamingSnakeCase => Self::snake_case(name).to_ascii_uppercase(),
            Self::KebabCase => Self::snake_case(name).replace('_', "-"),
            Self::ScreamingKebabCase => Self::snake_case(name)
                .to_ascii_uppercase()
                .replace('_', "-"),
        }
    }

    /// serde's snake_case transform; the screaming and kebab conventions build on it.
    fn snake_case(name: &str) -> String {
        let mut out = String::with_capacity(name.len() + 4);
        for (i, ch) in name.char_indices() {
            if i > 0 && ch.is_uppercase() {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
        }
        out
    }
}
581
582fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
584 for attr in attrs {
585 if attr.path().is_ident("serde")
586 && let Ok(meta_list) = attr.meta.require_list()
587 {
588 if let Ok(metas) =
590 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
591 {
592 for meta in metas {
593 if let Meta::NameValue(nv) = meta
594 && nv.path.is_ident("rename_all")
595 && let syn::Expr::Lit(syn::ExprLit {
596 lit: syn::Lit::Str(lit_str),
597 ..
598 }) = nv.value
599 {
600 return RenameRule::from_str(&lit_str.value());
601 }
602 }
603 }
604 }
605 }
606 None
607}
608
609fn parse_serde_tag(attrs: &[syn::Attribute]) -> Option<String> {
612 for attr in attrs {
613 if attr.path().is_ident("serde")
614 && let Ok(meta_list) = attr.meta.require_list()
615 {
616 if let Ok(metas) =
618 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
619 {
620 for meta in metas {
621 if let Meta::NameValue(nv) = meta
622 && nv.path.is_ident("tag")
623 && let syn::Expr::Lit(syn::ExprLit {
624 lit: syn::Lit::Str(lit_str),
625 ..
626 }) = nv.value
627 {
628 return Some(lit_str.value());
629 }
630 }
631 }
632 }
633 }
634 None
635}
636
637fn parse_serde_untagged(attrs: &[syn::Attribute]) -> bool {
640 for attr in attrs {
641 if attr.path().is_ident("serde")
642 && let Ok(meta_list) = attr.meta.require_list()
643 {
644 if let Ok(metas) =
646 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
647 {
648 for meta in metas {
649 if let Meta::Path(path) = meta
650 && path.is_ident("untagged")
651 {
652 return true;
653 }
654 }
655 }
656 }
657 }
658 false
659}
660
661#[derive(Debug, Default)]
663struct FieldPromptAttrs {
664 skip: bool,
665 rename: Option<String>,
666 format_with: Option<String>,
667 image: bool,
668 example: Option<String>,
669}
670
671fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
673 let mut result = FieldPromptAttrs::default();
674
675 for attr in attrs {
676 if attr.path().is_ident("prompt") {
677 if let Ok(meta_list) = attr.meta.require_list() {
679 if let Ok(metas) =
681 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
682 {
683 for meta in metas {
684 match meta {
685 Meta::Path(path) if path.is_ident("skip") => {
686 result.skip = true;
687 }
688 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
689 if let syn::Expr::Lit(syn::ExprLit {
690 lit: syn::Lit::Str(lit_str),
691 ..
692 }) = nv.value
693 {
694 result.rename = Some(lit_str.value());
695 }
696 }
697 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
698 if let syn::Expr::Lit(syn::ExprLit {
699 lit: syn::Lit::Str(lit_str),
700 ..
701 }) = nv.value
702 {
703 result.format_with = Some(lit_str.value());
704 }
705 }
706 Meta::Path(path) if path.is_ident("image") => {
707 result.image = true;
708 }
709 Meta::NameValue(nv) if nv.path.is_ident("example") => {
710 if let syn::Expr::Lit(syn::ExprLit {
711 lit: syn::Lit::Str(lit_str),
712 ..
713 }) = nv.value
714 {
715 result.example = Some(lit_str.value());
716 }
717 }
718 _ => {}
719 }
720 }
721 } else if meta_list.tokens.to_string() == "skip" {
722 result.skip = true;
724 } else if meta_list.tokens.to_string() == "image" {
725 result.image = true;
727 }
728 }
729 }
730 }
731
732 result
733}
734
735#[proc_macro_derive(ToPrompt, attributes(prompt))]
778pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
779 let input = parse_macro_input!(input as DeriveInput);
780
781 let found_crate =
782 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
783 let crate_path = match found_crate {
784 FoundCrate::Itself => {
785 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
787 quote!(::#ident)
788 }
789 FoundCrate::Name(name) => {
790 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
791 quote!(::#ident)
792 }
793 };
794
795 match &input.data {
797 Data::Enum(data_enum) => {
798 let enum_name = &input.ident;
800 let enum_docs = extract_doc_comments(&input.attrs);
801
802 let serde_tag = parse_serde_tag(&input.attrs);
804 let is_internally_tagged = serde_tag.is_some();
805 let is_untagged = parse_serde_untagged(&input.attrs);
806
807 let rename_rule = parse_serde_rename_all(&input.attrs);
809
810 let mut variant_lines = Vec::new();
823 let mut first_variant_name = None;
824
825 let mut example_unit: Option<String> = None;
827 let mut example_struct: Option<String> = None;
828 let mut example_tuple: Option<String> = None;
829
830 for variant in &data_enum.variants {
831 let variant_name = &variant.ident;
832 let variant_name_str = variant_name.to_string();
833
834 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
836
837 if prompt_attrs.skip {
839 continue;
840 }
841
842 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
848 prompt_rename.clone()
849 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
850 serde_rename
851 } else if let Some(rule) = rename_rule {
852 rule.apply(&variant_name_str)
853 } else {
854 variant_name_str.clone()
855 };
856
857 let variant_line = match &variant.fields {
859 syn::Fields::Unit => {
860 if example_unit.is_none() {
862 example_unit = Some(format!("\"{}\"", variant_value));
863 }
864
865 if let Some(desc) = &prompt_attrs.description {
867 format!(" | \"{}\" // {}", variant_value, desc)
868 } else {
869 let docs = extract_doc_comments(&variant.attrs);
870 if !docs.is_empty() {
871 format!(" | \"{}\" // {}", variant_value, docs)
872 } else {
873 format!(" | \"{}\"", variant_value)
874 }
875 }
876 }
877 syn::Fields::Named(fields) => {
878 let mut field_parts = Vec::new();
879 let mut example_field_parts = Vec::new();
880
881 if is_internally_tagged && let Some(tag_name) = &serde_tag {
883 field_parts.push(format!("{}: \"{}\"", tag_name, variant_value));
884 example_field_parts
885 .push(format!("{}: \"{}\"", tag_name, variant_value));
886 }
887
888 for field in &fields.named {
889 let field_name = field.ident.as_ref().unwrap().to_string();
890 let field_type = format_type_for_schema(&field.ty);
891 field_parts.push(format!("{}: {}", field_name, field_type.clone()));
892
893 let example_value = generate_example_value_for_type(&field_type);
895 example_field_parts.push(format!("{}: {}", field_name, example_value));
896 }
897
898 let field_str = field_parts.join(", ");
899 let example_field_str = example_field_parts.join(", ");
900
901 if example_struct.is_none() {
903 if is_untagged || is_internally_tagged {
904 example_struct = Some(format!("{{ {} }}", example_field_str));
905 } else {
906 example_struct = Some(format!(
907 "{{ \"{}\": {{ {} }} }}",
908 variant_value, example_field_str
909 ));
910 }
911 }
912
913 let comment = if let Some(desc) = &prompt_attrs.description {
914 format!(" // {}", desc)
915 } else {
916 let docs = extract_doc_comments(&variant.attrs);
917 if !docs.is_empty() {
918 format!(" // {}", docs)
919 } else if is_untagged {
920 format!(" // {}", variant_value)
922 } else {
923 String::new()
924 }
925 };
926
927 if is_untagged {
928 format!(" | {{ {} }}{}", field_str, comment)
930 } else if is_internally_tagged {
931 format!(" | {{ {} }}{}", field_str, comment)
933 } else {
934 format!(
936 " | {{ \"{}\": {{ {} }} }}{}",
937 variant_value, field_str, comment
938 )
939 }
940 }
941 syn::Fields::Unnamed(fields) => {
942 let field_types: Vec<String> = fields
943 .unnamed
944 .iter()
945 .map(|f| format_type_for_schema(&f.ty))
946 .collect();
947
948 let tuple_str = field_types.join(", ");
949
950 let example_values: Vec<String> = field_types
952 .iter()
953 .map(|type_str| generate_example_value_for_type(type_str))
954 .collect();
955 let example_tuple_str = example_values.join(", ");
956
957 if example_tuple.is_none() {
959 if is_untagged || is_internally_tagged {
960 example_tuple = Some(format!("[{}]", example_tuple_str));
961 } else {
962 example_tuple = Some(format!(
963 "{{ \"{}\": [{}] }}",
964 variant_value, example_tuple_str
965 ));
966 }
967 }
968
969 let comment = if let Some(desc) = &prompt_attrs.description {
970 format!(" // {}", desc)
971 } else {
972 let docs = extract_doc_comments(&variant.attrs);
973 if !docs.is_empty() {
974 format!(" // {}", docs)
975 } else if is_untagged {
976 format!(" // {}", variant_value)
978 } else {
979 String::new()
980 }
981 };
982
983 if is_untagged || is_internally_tagged {
984 format!(" | [{}]{}", tuple_str, comment)
987 } else {
988 format!(
990 " | {{ \"{}\": [{}] }}{}",
991 variant_value, tuple_str, comment
992 )
993 }
994 }
995 };
996
997 variant_lines.push(variant_line);
998
999 if first_variant_name.is_none() {
1000 first_variant_name = Some(variant_value);
1001 }
1002 }
1003
1004 let mut lines = Vec::new();
1006
1007 if !enum_docs.is_empty() {
1009 lines.push("/**".to_string());
1010 lines.push(format!(" * {}", enum_docs));
1011 lines.push(" */".to_string());
1012 }
1013
1014 lines.push(format!("type {} =", enum_name));
1016
1017 for line in &variant_lines {
1019 lines.push(line.clone());
1020 }
1021
1022 if let Some(last) = lines.last_mut()
1024 && !last.ends_with(';')
1025 {
1026 last.push(';');
1027 }
1028
1029 let mut examples = Vec::new();
1031 if let Some(ex) = example_unit {
1032 examples.push(ex);
1033 }
1034 if let Some(ex) = example_struct {
1035 examples.push(ex);
1036 }
1037 if let Some(ex) = example_tuple {
1038 examples.push(ex);
1039 }
1040
1041 if !examples.is_empty() {
1042 lines.push("".to_string()); if examples.len() == 1 {
1044 lines.push(format!("Example value: {}", examples[0]));
1045 } else {
1046 lines.push("Example values:".to_string());
1047 for ex in examples {
1048 lines.push(format!(" {}", ex));
1049 }
1050 }
1051 }
1052
1053 let prompt_string = lines.join("\n");
1054 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1055
1056 let mut match_arms = Vec::new();
1058 for variant in &data_enum.variants {
1059 let variant_name = &variant.ident;
1060 let variant_name_str = variant_name.to_string();
1061
1062 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1064
1065 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1071 prompt_rename.clone()
1072 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1073 serde_rename
1074 } else if let Some(rule) = rename_rule {
1075 rule.apply(&variant_name_str)
1076 } else {
1077 variant_name_str.clone()
1078 };
1079
1080 match &variant.fields {
1082 syn::Fields::Unit => {
1083 if prompt_attrs.skip {
1085 match_arms.push(quote! {
1086 Self::#variant_name => stringify!(#variant_name).to_string()
1087 });
1088 } else if let Some(desc) = &prompt_attrs.description {
1089 match_arms.push(quote! {
1090 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
1091 });
1092 } else {
1093 let variant_docs = extract_doc_comments(&variant.attrs);
1094 if !variant_docs.is_empty() {
1095 match_arms.push(quote! {
1096 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
1097 });
1098 } else {
1099 match_arms.push(quote! {
1100 Self::#variant_name => #variant_value.to_string()
1101 });
1102 }
1103 }
1104 }
1105 syn::Fields::Named(fields) => {
1106 let field_bindings: Vec<_> = fields
1108 .named
1109 .iter()
1110 .map(|f| f.ident.as_ref().unwrap())
1111 .collect();
1112
1113 let field_displays: Vec<_> = fields
1114 .named
1115 .iter()
1116 .map(|f| {
1117 let field_name = f.ident.as_ref().unwrap();
1118 let field_name_str = field_name.to_string();
1119 quote! {
1120 format!("{}: {:?}", #field_name_str, #field_name)
1121 }
1122 })
1123 .collect();
1124
1125 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1126 desc.clone()
1127 } else {
1128 let docs = extract_doc_comments(&variant.attrs);
1129 if !docs.is_empty() {
1130 docs
1131 } else {
1132 String::new()
1133 }
1134 };
1135
1136 if doc_or_desc.is_empty() {
1137 match_arms.push(quote! {
1138 Self::#variant_name { #(#field_bindings),* } => {
1139 let fields = vec![#(#field_displays),*];
1140 format!("{} {{ {} }}", #variant_value, fields.join(", "))
1141 }
1142 });
1143 } else {
1144 match_arms.push(quote! {
1145 Self::#variant_name { #(#field_bindings),* } => {
1146 let fields = vec![#(#field_displays),*];
1147 format!("{}: {} {{ {} }}", #variant_value, #doc_or_desc, fields.join(", "))
1148 }
1149 });
1150 }
1151 }
1152 syn::Fields::Unnamed(fields) => {
1153 let field_count = fields.unnamed.len();
1155 let field_bindings: Vec<_> = (0..field_count)
1156 .map(|i| {
1157 syn::Ident::new(
1158 &format!("field{}", i),
1159 proc_macro2::Span::call_site(),
1160 )
1161 })
1162 .collect();
1163
1164 let field_displays: Vec<_> = field_bindings
1165 .iter()
1166 .map(|field_name| {
1167 quote! {
1168 format!("{:?}", #field_name)
1169 }
1170 })
1171 .collect();
1172
1173 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1174 desc.clone()
1175 } else {
1176 let docs = extract_doc_comments(&variant.attrs);
1177 if !docs.is_empty() {
1178 docs
1179 } else {
1180 String::new()
1181 }
1182 };
1183
1184 if doc_or_desc.is_empty() {
1185 match_arms.push(quote! {
1186 Self::#variant_name(#(#field_bindings),*) => {
1187 let fields = vec![#(#field_displays),*];
1188 format!("{}({})", #variant_value, fields.join(", "))
1189 }
1190 });
1191 } else {
1192 match_arms.push(quote! {
1193 Self::#variant_name(#(#field_bindings),*) => {
1194 let fields = vec![#(#field_displays),*];
1195 format!("{}: {}({})", #variant_value, #doc_or_desc, fields.join(", "))
1196 }
1197 });
1198 }
1199 }
1200 }
1201 }
1202
1203 let to_prompt_impl = if match_arms.is_empty() {
1204 quote! {
1206 fn to_prompt(&self) -> String {
1207 match *self {}
1208 }
1209 }
1210 } else {
1211 quote! {
1212 fn to_prompt(&self) -> String {
1213 match self {
1214 #(#match_arms),*
1215 }
1216 }
1217 }
1218 };
1219
1220 let expanded = quote! {
1221 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
1222 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1223 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
1224 }
1225
1226 #to_prompt_impl
1227
1228 fn prompt_schema() -> String {
1229 #prompt_string.to_string()
1230 }
1231 }
1232 };
1233
1234 TokenStream::from(expanded)
1235 }
1236 Data::Struct(data_struct) => {
1237 let mut template_attr = None;
1239 let mut template_file_attr = None;
1240 let mut mode_attr = None;
1241 let mut validate_attr = false;
1242 let mut type_marker_attr = false;
1243
1244 for attr in &input.attrs {
1245 if attr.path().is_ident("prompt") {
1246 if let Ok(metas) =
1248 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1249 {
1250 for meta in metas {
1251 match meta {
1252 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1253 if let syn::Expr::Lit(expr_lit) = nv.value
1254 && let syn::Lit::Str(lit_str) = expr_lit.lit
1255 {
1256 template_attr = Some(lit_str.value());
1257 }
1258 }
1259 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
1260 if let syn::Expr::Lit(expr_lit) = nv.value
1261 && let syn::Lit::Str(lit_str) = expr_lit.lit
1262 {
1263 template_file_attr = Some(lit_str.value());
1264 }
1265 }
1266 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1267 if let syn::Expr::Lit(expr_lit) = nv.value
1268 && let syn::Lit::Str(lit_str) = expr_lit.lit
1269 {
1270 mode_attr = Some(lit_str.value());
1271 }
1272 }
1273 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
1274 if let syn::Expr::Lit(expr_lit) = nv.value
1275 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1276 {
1277 validate_attr = lit_bool.value();
1278 }
1279 }
1280 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
1281 if let syn::Expr::Lit(expr_lit) = nv.value
1282 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1283 {
1284 type_marker_attr = lit_bool.value();
1285 }
1286 }
1287 Meta::Path(path) if path.is_ident("type_marker") => {
1288 type_marker_attr = true;
1290 }
1291 _ => {}
1292 }
1293 }
1294 }
1295 }
1296 }
1297
1298 if template_attr.is_some() && template_file_attr.is_some() {
1300 return syn::Error::new(
1301 input.ident.span(),
1302 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
1303 ).to_compile_error().into();
1304 }
1305
1306 let template_str = if let Some(file_path) = template_file_attr {
1308 let mut full_path = None;
1312
1313 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1315 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
1317
1318 if !is_trybuild {
1319 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1321 if candidate.exists() {
1322 full_path = Some(candidate);
1323 }
1324 } else {
1325 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1331 let workspace_root = &manifest_dir[..target_pos];
1332 let original_macros_dir = std::path::Path::new(workspace_root)
1334 .join("crates")
1335 .join("llm-toolkit-macros");
1336
1337 let candidate = original_macros_dir.join(&file_path);
1338 if candidate.exists() {
1339 full_path = Some(candidate);
1340 }
1341 }
1342 }
1343 }
1344
1345 if full_path.is_none() {
1347 let candidate = std::path::Path::new(&file_path).to_path_buf();
1348 if candidate.exists() {
1349 full_path = Some(candidate);
1350 }
1351 }
1352
1353 if full_path.is_none()
1356 && let Ok(current_dir) = std::env::current_dir()
1357 {
1358 let mut search_dir = current_dir.as_path();
1359 for _ in 0..10 {
1361 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1363 if macros_dir.exists() {
1364 let candidate = macros_dir.join(&file_path);
1365 if candidate.exists() {
1366 full_path = Some(candidate);
1367 break;
1368 }
1369 }
1370 let candidate = search_dir.join(&file_path);
1372 if candidate.exists() {
1373 full_path = Some(candidate);
1374 break;
1375 }
1376 if let Some(parent) = search_dir.parent() {
1377 search_dir = parent;
1378 } else {
1379 break;
1380 }
1381 }
1382 }
1383
1384 if full_path.is_none() {
1386 let mut error_msg = format!(
1388 "Template file '{}' not found at compile time.\n\nSearched in:",
1389 file_path
1390 );
1391
1392 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1393 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1394 error_msg.push_str(&format!("\n - {}", candidate.display()));
1395 }
1396
1397 if let Ok(current_dir) = std::env::current_dir() {
1398 let candidate = current_dir.join(&file_path);
1399 error_msg.push_str(&format!("\n - {}", candidate.display()));
1400 }
1401
1402 error_msg.push_str("\n\nPlease ensure:");
1403 error_msg.push_str("\n 1. The template file exists");
1404 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1405 error_msg.push_str("\n 3. There are no typos in the path");
1406
1407 return syn::Error::new(input.ident.span(), error_msg)
1408 .to_compile_error()
1409 .into();
1410 }
1411
1412 let final_path = full_path.unwrap();
1413
1414 match std::fs::read_to_string(&final_path) {
1416 Ok(content) => Some(content),
1417 Err(e) => {
1418 return syn::Error::new(
1419 input.ident.span(),
1420 format!(
1421 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1422 file_path,
1423 e,
1424 final_path.display()
1425 ),
1426 )
1427 .to_compile_error()
1428 .into();
1429 }
1430 }
1431 } else {
1432 template_attr
1433 };
1434
1435 if validate_attr && let Some(template) = &template_str {
1437 let mut env = minijinja::Environment::new();
1439 if let Err(e) = env.add_template("validation", template) {
1440 let warning_msg =
1442 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1443 let warning_ident = syn::Ident::new(
1444 "TEMPLATE_VALIDATION_WARNING",
1445 proc_macro2::Span::call_site(),
1446 );
1447 let _warning_tokens = quote! {
1448 #[deprecated(note = #warning_msg)]
1449 const #warning_ident: () = ();
1450 let _ = #warning_ident;
1451 };
1452 eprintln!("cargo:warning={}", warning_msg);
1454 }
1455
1456 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1458 &fields.named
1459 } else {
1460 panic!("Template validation is only supported for structs with named fields.");
1461 };
1462
1463 let field_names: std::collections::HashSet<String> = fields
1464 .iter()
1465 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1466 .collect();
1467
1468 let placeholders = parse_template_placeholders_with_mode(template);
1470
1471 for (placeholder_name, _mode) in &placeholders {
1472 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1473 let warning_msg = format!(
1474 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1475 placeholder_name
1476 );
1477 eprintln!("cargo:warning={}", warning_msg);
1478 }
1479 }
1480 }
1481
1482 let name = input.ident;
1483 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1484
1485 let struct_docs = extract_doc_comments(&input.attrs);
1487
1488 let is_mode_based =
1490 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1491
1492 let expanded = if is_mode_based || mode_attr.is_some() {
1493 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1495 &fields.named
1496 } else {
1497 panic!(
1498 "Mode-based prompt generation is only supported for structs with named fields."
1499 );
1500 };
1501
1502 let struct_name_str = name.to_string();
1503
1504 let has_default = input.attrs.iter().any(|attr| {
1506 if attr.path().is_ident("derive")
1507 && let Ok(meta_list) = attr.meta.require_list()
1508 {
1509 let tokens_str = meta_list.tokens.to_string();
1510 tokens_str.contains("Default")
1511 } else {
1512 false
1513 }
1514 });
1515
1516 let schema_parts = generate_schema_only_parts(
1527 &struct_name_str,
1528 &struct_docs,
1529 fields,
1530 &crate_path,
1531 type_marker_attr,
1532 );
1533
1534 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1536
1537 quote! {
1538 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1539 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1540 match mode {
1541 "schema_only" => #schema_parts,
1542 "example_only" => #example_parts,
1543 "full" | _ => {
1544 let mut parts = Vec::new();
1546
1547 let schema_parts = #schema_parts;
1549 parts.extend(schema_parts);
1550
1551 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1553 parts.push(#crate_path::prompt::PromptPart::Text(
1554 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1555 ));
1556
1557 let example_parts = #example_parts;
1559 parts.extend(example_parts);
1560
1561 parts
1562 }
1563 }
1564 }
1565
1566 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1567 self.to_prompt_parts_with_mode("full")
1568 }
1569
1570 fn to_prompt(&self) -> String {
1571 self.to_prompt_parts()
1572 .into_iter()
1573 .filter_map(|part| match part {
1574 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1575 _ => None,
1576 })
1577 .collect::<Vec<_>>()
1578 .join("\n")
1579 }
1580
1581 fn prompt_schema() -> String {
1582 use std::sync::OnceLock;
1583 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1584
1585 SCHEMA_CACHE.get_or_init(|| {
1586 let schema_parts = #schema_parts;
1587 schema_parts
1588 .into_iter()
1589 .filter_map(|part| match part {
1590 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1591 _ => None,
1592 })
1593 .collect::<Vec<_>>()
1594 .join("\n")
1595 }).clone()
1596 }
1597 }
1598 }
1599 } else if let Some(template) = template_str {
1600 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1603 &fields.named
1604 } else {
1605 panic!(
1606 "Template prompt generation is only supported for structs with named fields."
1607 );
1608 };
1609
1610 let placeholders = parse_template_placeholders_with_mode(&template);
1612 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1614 mode.is_some()
1615 && fields
1616 .iter()
1617 .any(|f| f.ident.as_ref().unwrap() == field_name)
1618 });
1619
1620 let mut image_field_parts = Vec::new();
1621 for f in fields.iter() {
1622 let field_name = f.ident.as_ref().unwrap();
1623 let attrs = parse_field_prompt_attrs(&f.attrs);
1624
1625 if attrs.image {
1626 image_field_parts.push(quote! {
1628 parts.extend(self.#field_name.to_prompt_parts());
1629 });
1630 }
1631 }
1632
1633 if has_mode_syntax {
1635 let mut context_fields = Vec::new();
1637 let mut modified_template = template.clone();
1638
1639 for (field_name, mode_opt) in &placeholders {
1641 if let Some(mode) = mode_opt {
1642 let unique_key = format!("{}__{}", field_name, mode);
1644
1645 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1647 let replacement = format!("{{{{ {} }}}}", unique_key);
1648 modified_template = modified_template.replace(&pattern, &replacement);
1649
1650 let field_ident =
1652 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1653
1654 context_fields.push(quote! {
1656 context.insert(
1657 #unique_key.to_string(),
1658 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1659 );
1660 });
1661 }
1662 }
1663
1664 for field in fields.iter() {
1666 let field_name = field.ident.as_ref().unwrap();
1667 let field_name_str = field_name.to_string();
1668
1669 let has_mode_entry = placeholders
1671 .iter()
1672 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1673
1674 if !has_mode_entry {
1675 let is_primitive = match &field.ty {
1678 syn::Type::Path(type_path) => {
1679 if let Some(segment) = type_path.path.segments.last() {
1680 let type_name = segment.ident.to_string();
1681 matches!(
1682 type_name.as_str(),
1683 "String"
1684 | "str"
1685 | "i8"
1686 | "i16"
1687 | "i32"
1688 | "i64"
1689 | "i128"
1690 | "isize"
1691 | "u8"
1692 | "u16"
1693 | "u32"
1694 | "u64"
1695 | "u128"
1696 | "usize"
1697 | "f32"
1698 | "f64"
1699 | "bool"
1700 | "char"
1701 )
1702 } else {
1703 false
1704 }
1705 }
1706 _ => false,
1707 };
1708
1709 if is_primitive {
1710 context_fields.push(quote! {
1711 context.insert(
1712 #field_name_str.to_string(),
1713 minijinja::Value::from_serialize(&self.#field_name)
1714 );
1715 });
1716 } else {
1717 context_fields.push(quote! {
1719 context.insert(
1720 #field_name_str.to_string(),
1721 minijinja::Value::from(self.#field_name.to_prompt())
1722 );
1723 });
1724 }
1725 }
1726 }
1727
1728 quote! {
1729 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1730 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1731 let mut parts = Vec::new();
1732
1733 #(#image_field_parts)*
1735
1736 let text = {
1738 let mut env = minijinja::Environment::new();
1739 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1740 panic!("Failed to parse template: {}", e)
1741 });
1742
1743 let tmpl = env.get_template("prompt").unwrap();
1744
1745 let mut context = std::collections::HashMap::new();
1746 #(#context_fields)*
1747
1748 tmpl.render(context).unwrap_or_else(|e| {
1749 format!("Failed to render prompt: {}", e)
1750 })
1751 };
1752
1753 if !text.is_empty() {
1754 parts.push(#crate_path::prompt::PromptPart::Text(text));
1755 }
1756
1757 parts
1758 }
1759
1760 fn to_prompt(&self) -> String {
1761 let mut env = minijinja::Environment::new();
1763 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1764 panic!("Failed to parse template: {}", e)
1765 });
1766
1767 let tmpl = env.get_template("prompt").unwrap();
1768
1769 let mut context = std::collections::HashMap::new();
1770 #(#context_fields)*
1771
1772 tmpl.render(context).unwrap_or_else(|e| {
1773 format!("Failed to render prompt: {}", e)
1774 })
1775 }
1776
1777 fn prompt_schema() -> String {
1778 String::new() }
1780 }
1781 }
1782 } else {
1783 quote! {
1785 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1786 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1787 let mut parts = Vec::new();
1788
1789 #(#image_field_parts)*
1791
1792 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1794 format!("Failed to render prompt: {}", e)
1795 });
1796 if !text.is_empty() {
1797 parts.push(#crate_path::prompt::PromptPart::Text(text));
1798 }
1799
1800 parts
1801 }
1802
1803 fn to_prompt(&self) -> String {
1804 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1805 format!("Failed to render prompt: {}", e)
1806 })
1807 }
1808
1809 fn prompt_schema() -> String {
1810 String::new() }
1812 }
1813 }
1814 }
1815 } else {
1816 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1819 &fields.named
1820 } else {
1821 panic!(
1822 "Default prompt generation is only supported for structs with named fields."
1823 );
1824 };
1825
1826 let mut text_field_parts = Vec::new();
1828 let mut image_field_parts = Vec::new();
1829
1830 for f in fields.iter() {
1831 let field_name = f.ident.as_ref().unwrap();
1832 let attrs = parse_field_prompt_attrs(&f.attrs);
1833
1834 if attrs.skip {
1836 continue;
1837 }
1838
1839 if attrs.image {
1840 image_field_parts.push(quote! {
1842 parts.extend(self.#field_name.to_prompt_parts());
1843 });
1844 } else {
1845 let key = if let Some(rename) = attrs.rename {
1851 rename
1852 } else {
1853 let doc_comment = extract_doc_comments(&f.attrs);
1854 if !doc_comment.is_empty() {
1855 doc_comment
1856 } else {
1857 field_name.to_string()
1858 }
1859 };
1860
1861 let value_expr = if let Some(format_with) = attrs.format_with {
1863 let func_path: syn::Path =
1865 syn::parse_str(&format_with).unwrap_or_else(|_| {
1866 panic!("Invalid function path: {}", format_with)
1867 });
1868 quote! { #func_path(&self.#field_name) }
1869 } else {
1870 quote! { self.#field_name.to_prompt() }
1871 };
1872
1873 text_field_parts.push(quote! {
1874 text_parts.push(format!("{}: {}", #key, #value_expr));
1875 });
1876 }
1877 }
1878
1879 quote! {
1881 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1882 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1883 let mut parts = Vec::new();
1884
1885 #(#image_field_parts)*
1887
1888 let mut text_parts = Vec::new();
1890 #(#text_field_parts)*
1891
1892 if !text_parts.is_empty() {
1893 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1894 }
1895
1896 parts
1897 }
1898
1899 fn to_prompt(&self) -> String {
1900 let mut text_parts = Vec::new();
1901 #(#text_field_parts)*
1902 text_parts.join("\n")
1903 }
1904
1905 fn prompt_schema() -> String {
1906 String::new() }
1908 }
1909 }
1910 };
1911
1912 TokenStream::from(expanded)
1913 }
1914 Data::Union(_) => {
1915 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1916 }
1917 }
1918}
1919
/// One output target of `#[derive(ToPromptSet)]`.
///
/// Built from struct-level `#[prompt_for(name = ..., template = ...)]` attributes and
/// from targets first named in field-level `#[prompt_for(name = ...)]` attributes.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name; becomes a string pattern matched against the `target` argument of
    /// the generated `to_prompt_parts_for`.
    name: String,
    /// Template to render for this target. `None` means each included field is rendered
    /// as a `key: value` text line instead.
    template: Option<String>,
    /// Per-field settings for this target, keyed by the field identifier's string form.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}
1927
/// Settings for one field within one `ToPromptSet` target, parsed from the field's
/// `#[prompt_for(...)]` attribute(s).
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Leave the field out of this target's output.
    skip: bool,
    /// Label used instead of the field name in the `key: value` line.
    rename: Option<String>,
    /// Path of a function called as `func(&self.field)` to produce the value text.
    format_with: Option<String>,
    /// Emit the field's own prompt parts (`to_prompt_parts()`) instead of a text line.
    image: bool,
    /// Set when the config came from `#[prompt_for(name = ...)]`, i.e. the field was
    /// explicitly opted into this particular target.
    include_only: bool, }
1937
1938fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1940 let mut configs = Vec::new();
1941
1942 for attr in attrs {
1943 if attr.path().is_ident("prompt_for")
1944 && let Ok(meta_list) = attr.meta.require_list()
1945 {
1946 if meta_list.tokens.to_string() == "skip" {
1948 let config = FieldTargetConfig {
1950 skip: true,
1951 ..Default::default()
1952 };
1953 configs.push(("*".to_string(), config));
1954 } else if let Ok(metas) =
1955 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1956 {
1957 let mut target_name = None;
1958 let mut config = FieldTargetConfig::default();
1959
1960 for meta in metas {
1961 match meta {
1962 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1963 if let syn::Expr::Lit(syn::ExprLit {
1964 lit: syn::Lit::Str(lit_str),
1965 ..
1966 }) = nv.value
1967 {
1968 target_name = Some(lit_str.value());
1969 }
1970 }
1971 Meta::Path(path) if path.is_ident("skip") => {
1972 config.skip = true;
1973 }
1974 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1975 if let syn::Expr::Lit(syn::ExprLit {
1976 lit: syn::Lit::Str(lit_str),
1977 ..
1978 }) = nv.value
1979 {
1980 config.rename = Some(lit_str.value());
1981 }
1982 }
1983 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1984 if let syn::Expr::Lit(syn::ExprLit {
1985 lit: syn::Lit::Str(lit_str),
1986 ..
1987 }) = nv.value
1988 {
1989 config.format_with = Some(lit_str.value());
1990 }
1991 }
1992 Meta::Path(path) if path.is_ident("image") => {
1993 config.image = true;
1994 }
1995 _ => {}
1996 }
1997 }
1998
1999 if let Some(name) = target_name {
2000 config.include_only = true;
2001 configs.push((name, config));
2002 }
2003 }
2004 }
2005 }
2006
2007 configs
2008}
2009
2010fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
2012 let mut targets = Vec::new();
2013
2014 for attr in attrs {
2015 if attr.path().is_ident("prompt_for")
2016 && let Ok(meta_list) = attr.meta.require_list()
2017 && let Ok(metas) =
2018 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2019 {
2020 let mut target_name = None;
2021 let mut template = None;
2022
2023 for meta in metas {
2024 match meta {
2025 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2026 if let syn::Expr::Lit(syn::ExprLit {
2027 lit: syn::Lit::Str(lit_str),
2028 ..
2029 }) = nv.value
2030 {
2031 target_name = Some(lit_str.value());
2032 }
2033 }
2034 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2035 if let syn::Expr::Lit(syn::ExprLit {
2036 lit: syn::Lit::Str(lit_str),
2037 ..
2038 }) = nv.value
2039 {
2040 template = Some(lit_str.value());
2041 }
2042 }
2043 _ => {}
2044 }
2045 }
2046
2047 if let Some(name) = target_name {
2048 targets.push(TargetInfo {
2049 name,
2050 template,
2051 field_configs: std::collections::HashMap::new(),
2052 });
2053 }
2054 }
2055 }
2056
2057 targets
2058}
2059
2060#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
2061pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
2062 let input = parse_macro_input!(input as DeriveInput);
2063
2064 let found_crate =
2065 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2066 let crate_path = match found_crate {
2067 FoundCrate::Itself => {
2068 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2070 quote!(::#ident)
2071 }
2072 FoundCrate::Name(name) => {
2073 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2074 quote!(::#ident)
2075 }
2076 };
2077
2078 let data_struct = match &input.data {
2080 Data::Struct(data) => data,
2081 _ => {
2082 return syn::Error::new(
2083 input.ident.span(),
2084 "`#[derive(ToPromptSet)]` is only supported for structs",
2085 )
2086 .to_compile_error()
2087 .into();
2088 }
2089 };
2090
2091 let fields = match &data_struct.fields {
2092 syn::Fields::Named(fields) => &fields.named,
2093 _ => {
2094 return syn::Error::new(
2095 input.ident.span(),
2096 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
2097 )
2098 .to_compile_error()
2099 .into();
2100 }
2101 };
2102
2103 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
2105
2106 for field in fields.iter() {
2108 let field_name = field.ident.as_ref().unwrap().to_string();
2109 let field_configs = parse_prompt_for_attrs(&field.attrs);
2110
2111 for (target_name, config) in field_configs {
2112 if target_name == "*" {
2113 for target in &mut targets {
2115 target
2116 .field_configs
2117 .entry(field_name.clone())
2118 .or_insert_with(FieldTargetConfig::default)
2119 .skip = config.skip;
2120 }
2121 } else {
2122 let target_exists = targets.iter().any(|t| t.name == target_name);
2124 if !target_exists {
2125 targets.push(TargetInfo {
2127 name: target_name.clone(),
2128 template: None,
2129 field_configs: std::collections::HashMap::new(),
2130 });
2131 }
2132
2133 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
2134
2135 target.field_configs.insert(field_name.clone(), config);
2136 }
2137 }
2138 }
2139
2140 let mut match_arms = Vec::new();
2142
2143 for target in &targets {
2144 let target_name = &target.name;
2145
2146 if let Some(template_str) = &target.template {
2147 let mut image_parts = Vec::new();
2149
2150 for field in fields.iter() {
2151 let field_name = field.ident.as_ref().unwrap();
2152 let field_name_str = field_name.to_string();
2153
2154 if let Some(config) = target.field_configs.get(&field_name_str)
2155 && config.image
2156 {
2157 image_parts.push(quote! {
2158 parts.extend(self.#field_name.to_prompt_parts());
2159 });
2160 }
2161 }
2162
2163 match_arms.push(quote! {
2164 #target_name => {
2165 let mut parts = Vec::new();
2166
2167 #(#image_parts)*
2168
2169 let text = #crate_path::prompt::render_prompt(#template_str, self)
2170 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
2171 target: #target_name.to_string(),
2172 source: e,
2173 })?;
2174
2175 if !text.is_empty() {
2176 parts.push(#crate_path::prompt::PromptPart::Text(text));
2177 }
2178
2179 Ok(parts)
2180 }
2181 });
2182 } else {
2183 let mut text_field_parts = Vec::new();
2185 let mut image_field_parts = Vec::new();
2186
2187 for field in fields.iter() {
2188 let field_name = field.ident.as_ref().unwrap();
2189 let field_name_str = field_name.to_string();
2190
2191 let config = target.field_configs.get(&field_name_str);
2193
2194 if let Some(cfg) = config
2196 && cfg.skip
2197 {
2198 continue;
2199 }
2200
2201 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
2205 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
2206 .iter()
2207 .any(|(name, _)| name != "*");
2208
2209 if has_any_target_specific_config && !is_explicitly_for_this_target {
2210 continue;
2211 }
2212
2213 if let Some(cfg) = config {
2214 if cfg.image {
2215 image_field_parts.push(quote! {
2216 parts.extend(self.#field_name.to_prompt_parts());
2217 });
2218 } else {
2219 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
2220
2221 let value_expr = if let Some(format_with) = &cfg.format_with {
2222 match syn::parse_str::<syn::Path>(format_with) {
2224 Ok(func_path) => quote! { #func_path(&self.#field_name) },
2225 Err(_) => {
2226 let error_msg = format!(
2228 "Invalid function path in format_with: '{}'",
2229 format_with
2230 );
2231 quote! {
2232 compile_error!(#error_msg);
2233 String::new()
2234 }
2235 }
2236 }
2237 } else {
2238 quote! { self.#field_name.to_prompt() }
2239 };
2240
2241 text_field_parts.push(quote! {
2242 text_parts.push(format!("{}: {}", #key, #value_expr));
2243 });
2244 }
2245 } else {
2246 text_field_parts.push(quote! {
2248 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
2249 });
2250 }
2251 }
2252
2253 match_arms.push(quote! {
2254 #target_name => {
2255 let mut parts = Vec::new();
2256
2257 #(#image_field_parts)*
2258
2259 let mut text_parts = Vec::new();
2260 #(#text_field_parts)*
2261
2262 if !text_parts.is_empty() {
2263 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2264 }
2265
2266 Ok(parts)
2267 }
2268 });
2269 }
2270 }
2271
2272 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
2274
2275 match_arms.push(quote! {
2277 _ => {
2278 let available = vec![#(#target_names.to_string()),*];
2279 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
2280 target: target.to_string(),
2281 available,
2282 })
2283 }
2284 });
2285
2286 let struct_name = &input.ident;
2287 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2288
2289 let expanded = quote! {
2290 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
2291 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
2292 match target {
2293 #(#match_arms)*
2294 }
2295 }
2296 }
2297 };
2298
2299 TokenStream::from(expanded)
2300}
2301
2302struct TypeList {
2304 types: Punctuated<syn::Type, Token![,]>,
2305}
2306
2307impl Parse for TypeList {
2308 fn parse(input: ParseStream) -> syn::Result<Self> {
2309 Ok(TypeList {
2310 types: Punctuated::parse_terminated(input)?,
2311 })
2312 }
2313}
2314
2315#[proc_macro]
2339pub fn examples_section(input: TokenStream) -> TokenStream {
2340 let input = parse_macro_input!(input as TypeList);
2341
2342 let found_crate =
2343 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2344 let _crate_path = match found_crate {
2345 FoundCrate::Itself => quote!(crate),
2346 FoundCrate::Name(name) => {
2347 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2348 quote!(::#ident)
2349 }
2350 };
2351
2352 let mut type_sections = Vec::new();
2354
2355 for ty in input.types.iter() {
2356 let type_name_str = quote!(#ty).to_string();
2358
2359 type_sections.push(quote! {
2361 {
2362 let type_name = #type_name_str;
2363 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
2364 format!("---\n#### `{}`\n{}", type_name, json_example)
2365 }
2366 });
2367 }
2368
2369 let expanded = quote! {
2371 {
2372 let mut sections = Vec::new();
2373 sections.push("---".to_string());
2374 sections.push("### Examples".to_string());
2375 sections.push("".to_string());
2376 sections.push("Here are examples of the data structures you should use.".to_string());
2377 sections.push("".to_string());
2378
2379 #(sections.push(#type_sections);)*
2380
2381 sections.push("---".to_string());
2382
2383 sections.join("\n")
2384 }
2385 };
2386
2387 TokenStream::from(expanded)
2388}
2389
2390fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
2392 for attr in attrs {
2393 if attr.path().is_ident("prompt_for")
2394 && let Ok(meta_list) = attr.meta.require_list()
2395 && let Ok(metas) =
2396 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2397 {
2398 let mut target_type = None;
2399 let mut template = None;
2400
2401 for meta in metas {
2402 match meta {
2403 Meta::NameValue(nv) if nv.path.is_ident("target") => {
2404 if let syn::Expr::Lit(syn::ExprLit {
2405 lit: syn::Lit::Str(lit_str),
2406 ..
2407 }) = nv.value
2408 {
2409 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2411 }
2412 }
2413 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2414 if let syn::Expr::Lit(syn::ExprLit {
2415 lit: syn::Lit::Str(lit_str),
2416 ..
2417 }) = nv.value
2418 {
2419 template = Some(lit_str.value());
2420 }
2421 }
2422 _ => {}
2423 }
2424 }
2425
2426 if let (Some(target), Some(tmpl)) = (target_type, template) {
2427 return (target, tmpl);
2428 }
2429 }
2430 }
2431
2432 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2433}
2434
2435#[proc_macro_attribute]
2469pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2470 let input = parse_macro_input!(item as DeriveInput);
2471
2472 let found_crate =
2473 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2474 let crate_path = match found_crate {
2475 FoundCrate::Itself => {
2476 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2478 quote!(::#ident)
2479 }
2480 FoundCrate::Name(name) => {
2481 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2482 quote!(::#ident)
2483 }
2484 };
2485
2486 let enum_data = match &input.data {
2488 Data::Enum(data) => data,
2489 _ => {
2490 return syn::Error::new(
2491 input.ident.span(),
2492 "`#[define_intent]` can only be applied to enums",
2493 )
2494 .to_compile_error()
2495 .into();
2496 }
2497 };
2498
2499 let mut prompt_template = None;
2501 let mut extractor_tag = None;
2502 let mut mode = None;
2503
2504 for attr in &input.attrs {
2505 if attr.path().is_ident("intent")
2506 && let Ok(metas) =
2507 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2508 {
2509 for meta in metas {
2510 match meta {
2511 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2512 if let syn::Expr::Lit(syn::ExprLit {
2513 lit: syn::Lit::Str(lit_str),
2514 ..
2515 }) = nv.value
2516 {
2517 prompt_template = Some(lit_str.value());
2518 }
2519 }
2520 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2521 if let syn::Expr::Lit(syn::ExprLit {
2522 lit: syn::Lit::Str(lit_str),
2523 ..
2524 }) = nv.value
2525 {
2526 extractor_tag = Some(lit_str.value());
2527 }
2528 }
2529 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2530 if let syn::Expr::Lit(syn::ExprLit {
2531 lit: syn::Lit::Str(lit_str),
2532 ..
2533 }) = nv.value
2534 {
2535 mode = Some(lit_str.value());
2536 }
2537 }
2538 _ => {}
2539 }
2540 }
2541 }
2542 }
2543
2544 let mode = mode.unwrap_or_else(|| "single".to_string());
2546
2547 if mode != "single" && mode != "multi_tag" {
2549 return syn::Error::new(
2550 input.ident.span(),
2551 "`mode` must be either \"single\" or \"multi_tag\"",
2552 )
2553 .to_compile_error()
2554 .into();
2555 }
2556
2557 let prompt_template = match prompt_template {
2559 Some(p) => p,
2560 None => {
2561 return syn::Error::new(
2562 input.ident.span(),
2563 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2564 )
2565 .to_compile_error()
2566 .into();
2567 }
2568 };
2569
2570 if mode == "multi_tag" {
2572 let enum_name = &input.ident;
2573 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2574 return generate_multi_tag_output(
2575 &input,
2576 enum_name,
2577 enum_data,
2578 prompt_template,
2579 actions_doc,
2580 );
2581 }
2582
2583 let extractor_tag = match extractor_tag {
2585 Some(t) => t,
2586 None => {
2587 return syn::Error::new(
2588 input.ident.span(),
2589 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2590 )
2591 .to_compile_error()
2592 .into();
2593 }
2594 };
2595
2596 let enum_name = &input.ident;
2598 let enum_docs = extract_doc_comments(&input.attrs);
2599
2600 let mut intents_doc_lines = Vec::new();
2601
2602 if !enum_docs.is_empty() {
2604 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2605 } else {
2606 intents_doc_lines.push(format!("{}:", enum_name));
2607 }
2608 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2610
2611 for variant in &enum_data.variants {
2613 let variant_name = &variant.ident;
2614 let variant_docs = extract_doc_comments(&variant.attrs);
2615
2616 if !variant_docs.is_empty() {
2617 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2618 } else {
2619 intents_doc_lines.push(format!("- {}", variant_name));
2620 }
2621 }
2622
2623 let intents_doc_str = intents_doc_lines.join("\n");
2624
2625 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2627 let user_variables: Vec<String> = placeholders
2628 .iter()
2629 .filter_map(|(name, _)| {
2630 if name != "intents_doc" {
2631 Some(name.clone())
2632 } else {
2633 None
2634 }
2635 })
2636 .collect();
2637
2638 let enum_name_str = enum_name.to_string();
2640 let snake_case_name = to_snake_case(&enum_name_str);
2641 let function_name = syn::Ident::new(
2642 &format!("build_{}_prompt", snake_case_name),
2643 proc_macro2::Span::call_site(),
2644 );
2645
2646 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2648 .iter()
2649 .map(|var| {
2650 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2651 quote! { #ident: &str }
2652 })
2653 .collect();
2654
2655 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2657 .iter()
2658 .map(|var| {
2659 let var_str = var.clone();
2660 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2661 quote! {
2662 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2663 }
2664 })
2665 .collect();
2666
2667 let converted_template = prompt_template.clone();
2669
2670 let extractor_name = syn::Ident::new(
2672 &format!("{}Extractor", enum_name),
2673 proc_macro2::Span::call_site(),
2674 );
2675
2676 let filtered_attrs: Vec<_> = input
2678 .attrs
2679 .iter()
2680 .filter(|attr| !attr.path().is_ident("intent"))
2681 .collect();
2682
2683 let vis = &input.vis;
2685 let generics = &input.generics;
2686 let variants = &enum_data.variants;
2687 let enum_output = quote! {
2688 #(#filtered_attrs)*
2689 #vis enum #enum_name #generics {
2690 #variants
2691 }
2692 };
2693
2694 let expanded = quote! {
2696 #enum_output
2698
2699 pub fn #function_name(#(#function_params),*) -> String {
2701 let mut env = minijinja::Environment::new();
2702 env.add_template("prompt", #converted_template)
2703 .expect("Failed to parse intent prompt template");
2704
2705 let tmpl = env.get_template("prompt").unwrap();
2706
2707 let mut __template_context = std::collections::HashMap::new();
2708
2709 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2711
2712 #(#context_insertions)*
2714
2715 tmpl.render(&__template_context)
2716 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2717 }
2718
2719 pub struct #extractor_name;
2721
2722 impl #extractor_name {
2723 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2724 }
2725
2726 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2727 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2728 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2730 }
2731 }
2732 };
2733
2734 TokenStream::from(expanded)
2735}
2736
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when the previous
/// character was not uppercase. Runs of capitals therefore stay together:
/// `"HTTPRequest"` becomes `"httprequest"`, while `"MyIntent"` becomes `"my_intent"`.
/// The full lowercase mapping of each uppercase character is used, so characters that
/// lowercase to several code points (e.g. U+0130) are not truncated.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` may yield more than one char; keep all of them.
            result.extend(ch.to_lowercase());
            prev_upper = true;
        } else {
            result.push(ch);
            prev_upper = false;
        }
    }

    result
}
2757
/// Variant-level settings parsed from `#[action(...)]` on multi-tag intent enum variants.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// Value of `tag = "..."`: the tag name rendered (e.g. as `<tag />`) for the variant.
    tag: Option<String>,
}
2763
2764fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2765 let mut result = ActionAttrs::default();
2766
2767 for attr in attrs {
2768 if attr.path().is_ident("action")
2769 && let Ok(meta_list) = attr.meta.require_list()
2770 && let Ok(metas) =
2771 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2772 {
2773 for meta in metas {
2774 if let Meta::NameValue(nv) = meta
2775 && nv.path.is_ident("tag")
2776 && let syn::Expr::Lit(syn::ExprLit {
2777 lit: syn::Lit::Str(lit_str),
2778 ..
2779 }) = nv.value
2780 {
2781 result.tag = Some(lit_str.value());
2782 }
2783 }
2784 }
2785 }
2786
2787 result
2788}
2789
2790#[derive(Debug, Default)]
2792struct FieldActionAttrs {
2793 is_attribute: bool,
2794 is_inner_text: bool,
2795}
2796
2797fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2798 let mut result = FieldActionAttrs::default();
2799
2800 for attr in attrs {
2801 if attr.path().is_ident("action")
2802 && let Ok(meta_list) = attr.meta.require_list()
2803 {
2804 let tokens_str = meta_list.tokens.to_string();
2805 if tokens_str == "attribute" {
2806 result.is_attribute = true;
2807 } else if tokens_str == "inner_text" {
2808 result.is_inner_text = true;
2809 }
2810 }
2811 }
2812
2813 result
2814}
2815
2816fn generate_multi_tag_actions_doc(
2818 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2819) -> String {
2820 let mut doc_lines = Vec::new();
2821
2822 for variant in variants {
2823 let action_attrs = parse_action_attrs(&variant.attrs);
2824
2825 if let Some(tag) = action_attrs.tag {
2826 let variant_docs = extract_doc_comments(&variant.attrs);
2827
2828 match &variant.fields {
2829 syn::Fields::Unit => {
2830 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2832 }
2833 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2834 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2836 }
2837 syn::Fields::Named(fields) => {
2838 let mut attrs_str = Vec::new();
2840 let mut has_inner_text = false;
2841
2842 for field in &fields.named {
2843 let field_name = field.ident.as_ref().unwrap();
2844 let field_attrs = parse_field_action_attrs(&field.attrs);
2845
2846 if field_attrs.is_attribute {
2847 attrs_str.push(format!("{}=\"...\"", field_name));
2848 } else if field_attrs.is_inner_text {
2849 has_inner_text = true;
2850 }
2851 }
2852
2853 let attrs_part = if !attrs_str.is_empty() {
2854 format!(" {}", attrs_str.join(" "))
2855 } else {
2856 String::new()
2857 };
2858
2859 if has_inner_text {
2860 doc_lines.push(format!(
2861 "- `<{}{}>...</{}>`: {}",
2862 tag, attrs_part, tag, variant_docs
2863 ));
2864 } else if !attrs_str.is_empty() {
2865 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2866 } else {
2867 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2868 }
2869
2870 for field in &fields.named {
2872 let field_name = field.ident.as_ref().unwrap();
2873 let field_attrs = parse_field_action_attrs(&field.attrs);
2874 let field_docs = extract_doc_comments(&field.attrs);
2875
2876 if field_attrs.is_attribute {
2877 doc_lines
2878 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2879 } else if field_attrs.is_inner_text {
2880 doc_lines
2881 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2882 }
2883 }
2884 }
2885 _ => {
2886 }
2888 }
2889 }
2890 }
2891
2892 doc_lines.join("\n")
2893}
2894
2895fn generate_tags_regex(
2897 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2898) -> String {
2899 let mut tag_names = Vec::new();
2900
2901 for variant in variants {
2902 let action_attrs = parse_action_attrs(&variant.attrs);
2903 if let Some(tag) = action_attrs.tag {
2904 tag_names.push(tag);
2905 }
2906 }
2907
2908 if tag_names.is_empty() {
2909 return String::new();
2910 }
2911
2912 let tags_pattern = tag_names.join("|");
2913 format!(
2916 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2917 tags_pattern, tags_pattern, tags_pattern
2918 )
2919}
2920
2921fn generate_multi_tag_output(
2923 input: &DeriveInput,
2924 enum_name: &syn::Ident,
2925 enum_data: &syn::DataEnum,
2926 prompt_template: String,
2927 actions_doc: String,
2928) -> TokenStream {
2929 let found_crate =
2930 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2931 let crate_path = match found_crate {
2932 FoundCrate::Itself => {
2933 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2935 quote!(::#ident)
2936 }
2937 FoundCrate::Name(name) => {
2938 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2939 quote!(::#ident)
2940 }
2941 };
2942
2943 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2945 let user_variables: Vec<String> = placeholders
2946 .iter()
2947 .filter_map(|(name, _)| {
2948 if name != "actions_doc" {
2949 Some(name.clone())
2950 } else {
2951 None
2952 }
2953 })
2954 .collect();
2955
2956 let enum_name_str = enum_name.to_string();
2958 let snake_case_name = to_snake_case(&enum_name_str);
2959 let function_name = syn::Ident::new(
2960 &format!("build_{}_prompt", snake_case_name),
2961 proc_macro2::Span::call_site(),
2962 );
2963
2964 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2966 .iter()
2967 .map(|var| {
2968 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2969 quote! { #ident: &str }
2970 })
2971 .collect();
2972
2973 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2975 .iter()
2976 .map(|var| {
2977 let var_str = var.clone();
2978 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2979 quote! {
2980 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2981 }
2982 })
2983 .collect();
2984
2985 let extractor_name = syn::Ident::new(
2987 &format!("{}Extractor", enum_name),
2988 proc_macro2::Span::call_site(),
2989 );
2990
2991 let filtered_attrs: Vec<_> = input
2993 .attrs
2994 .iter()
2995 .filter(|attr| !attr.path().is_ident("intent"))
2996 .collect();
2997
2998 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
3000 .variants
3001 .iter()
3002 .map(|variant| {
3003 let variant_name = &variant.ident;
3004 let variant_attrs: Vec<_> = variant
3005 .attrs
3006 .iter()
3007 .filter(|attr| !attr.path().is_ident("action"))
3008 .collect();
3009 let fields = &variant.fields;
3010
3011 let filtered_fields = match fields {
3013 syn::Fields::Named(named_fields) => {
3014 let filtered: Vec<_> = named_fields
3015 .named
3016 .iter()
3017 .map(|field| {
3018 let field_name = &field.ident;
3019 let field_type = &field.ty;
3020 let field_vis = &field.vis;
3021 let filtered_attrs: Vec<_> = field
3022 .attrs
3023 .iter()
3024 .filter(|attr| !attr.path().is_ident("action"))
3025 .collect();
3026 quote! {
3027 #(#filtered_attrs)*
3028 #field_vis #field_name: #field_type
3029 }
3030 })
3031 .collect();
3032 quote! { { #(#filtered,)* } }
3033 }
3034 syn::Fields::Unnamed(unnamed_fields) => {
3035 let types: Vec<_> = unnamed_fields
3036 .unnamed
3037 .iter()
3038 .map(|field| {
3039 let field_type = &field.ty;
3040 quote! { #field_type }
3041 })
3042 .collect();
3043 quote! { (#(#types),*) }
3044 }
3045 syn::Fields::Unit => quote! {},
3046 };
3047
3048 quote! {
3049 #(#variant_attrs)*
3050 #variant_name #filtered_fields
3051 }
3052 })
3053 .collect();
3054
3055 let vis = &input.vis;
3056 let generics = &input.generics;
3057
3058 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
3060
3061 let tags_regex = generate_tags_regex(&enum_data.variants);
3063
3064 let expanded = quote! {
3065 #(#filtered_attrs)*
3067 #vis enum #enum_name #generics {
3068 #(#filtered_variants),*
3069 }
3070
3071 pub fn #function_name(#(#function_params),*) -> String {
3073 let mut env = minijinja::Environment::new();
3074 env.add_template("prompt", #prompt_template)
3075 .expect("Failed to parse intent prompt template");
3076
3077 let tmpl = env.get_template("prompt").unwrap();
3078
3079 let mut __template_context = std::collections::HashMap::new();
3080
3081 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
3083
3084 #(#context_insertions)*
3086
3087 tmpl.render(&__template_context)
3088 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3089 }
3090
3091 pub struct #extractor_name;
3093
3094 impl #extractor_name {
3095 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
3096 use ::quick_xml::events::Event;
3097 use ::quick_xml::Reader;
3098
3099 let mut actions = Vec::new();
3100 let mut reader = Reader::from_str(text);
3101 reader.config_mut().trim_text(true);
3102
3103 let mut buf = Vec::new();
3104
3105 loop {
3106 match reader.read_event_into(&mut buf) {
3107 Ok(Event::Start(e)) => {
3108 let owned_e = e.into_owned();
3109 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3110 let is_empty = false;
3111
3112 #parsing_arms
3113 }
3114 Ok(Event::Empty(e)) => {
3115 let owned_e = e.into_owned();
3116 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3117 let is_empty = true;
3118
3119 #parsing_arms
3120 }
3121 Ok(Event::Eof) => break,
3122 Err(_) => {
3123 break;
3125 }
3126 _ => {}
3127 }
3128 buf.clear();
3129 }
3130
3131 actions.into_iter().next()
3132 }
3133
3134 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
3135 use ::quick_xml::events::Event;
3136 use ::quick_xml::Reader;
3137
3138 let mut actions = Vec::new();
3139 let mut reader = Reader::from_str(text);
3140 reader.config_mut().trim_text(true);
3141
3142 let mut buf = Vec::new();
3143
3144 loop {
3145 match reader.read_event_into(&mut buf) {
3146 Ok(Event::Start(e)) => {
3147 let owned_e = e.into_owned();
3148 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3149 let is_empty = false;
3150
3151 #parsing_arms
3152 }
3153 Ok(Event::Empty(e)) => {
3154 let owned_e = e.into_owned();
3155 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3156 let is_empty = true;
3157
3158 #parsing_arms
3159 }
3160 Ok(Event::Eof) => break,
3161 Err(_) => {
3162 break;
3164 }
3165 _ => {}
3166 }
3167 buf.clear();
3168 }
3169
3170 Ok(actions)
3171 }
3172
3173 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
3174 where
3175 F: FnMut(#enum_name) -> String,
3176 {
3177 use ::regex::Regex;
3178
3179 let regex_pattern = #tags_regex;
3180 if regex_pattern.is_empty() {
3181 return text.to_string();
3182 }
3183
3184 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
3185 panic!("Failed to compile regex for action tags: {}", e);
3186 });
3187
3188 re.replace_all(text, |caps: &::regex::Captures| {
3189 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
3190
3191 if let Some(action) = self.parse_single_action(matched) {
3193 transformer(action)
3194 } else {
3195 matched.to_string()
3197 }
3198 }).to_string()
3199 }
3200
3201 pub fn strip_actions(&self, text: &str) -> String {
3202 self.transform_actions(text, |_| String::new())
3203 }
3204 }
3205 };
3206
3207 TokenStream::from(expanded)
3208}
3209
3210fn generate_parsing_arms(
3212 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3213 enum_name: &syn::Ident,
3214) -> proc_macro2::TokenStream {
3215 let mut arms = Vec::new();
3216
3217 for variant in variants {
3218 let variant_name = &variant.ident;
3219 let action_attrs = parse_action_attrs(&variant.attrs);
3220
3221 if let Some(tag) = action_attrs.tag {
3222 match &variant.fields {
3223 syn::Fields::Unit => {
3224 arms.push(quote! {
3226 if &tag_name == #tag {
3227 actions.push(#enum_name::#variant_name);
3228 }
3229 });
3230 }
3231 syn::Fields::Unnamed(_fields) => {
3232 arms.push(quote! {
3234 if &tag_name == #tag && !is_empty {
3235 match reader.read_text(owned_e.name()) {
3237 Ok(text) => {
3238 actions.push(#enum_name::#variant_name(text.to_string()));
3239 }
3240 Err(_) => {
3241 actions.push(#enum_name::#variant_name(String::new()));
3243 }
3244 }
3245 }
3246 });
3247 }
3248 syn::Fields::Named(fields) => {
3249 let mut field_names = Vec::new();
3251 let mut has_inner_text_field = None;
3252
3253 for field in &fields.named {
3254 let field_name = field.ident.as_ref().unwrap();
3255 let field_attrs = parse_field_action_attrs(&field.attrs);
3256
3257 if field_attrs.is_attribute {
3258 field_names.push(field_name.clone());
3259 } else if field_attrs.is_inner_text {
3260 has_inner_text_field = Some(field_name.clone());
3261 }
3262 }
3263
3264 if let Some(inner_text_field) = has_inner_text_field {
3265 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3268 quote! {
3269 let mut #field_name = String::new();
3270 for attr in owned_e.attributes() {
3271 if let Ok(attr) = attr {
3272 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3273 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3274 break;
3275 }
3276 }
3277 }
3278 }
3279 }).collect();
3280
3281 arms.push(quote! {
3282 if &tag_name == #tag {
3283 #(#attr_extractions)*
3284
3285 if is_empty {
3287 let #inner_text_field = String::new();
3288 actions.push(#enum_name::#variant_name {
3289 #(#field_names,)*
3290 #inner_text_field,
3291 });
3292 } else {
3293 match reader.read_text(owned_e.name()) {
3295 Ok(text) => {
3296 let #inner_text_field = text.to_string();
3297 actions.push(#enum_name::#variant_name {
3298 #(#field_names,)*
3299 #inner_text_field,
3300 });
3301 }
3302 Err(_) => {
3303 let #inner_text_field = String::new();
3305 actions.push(#enum_name::#variant_name {
3306 #(#field_names,)*
3307 #inner_text_field,
3308 });
3309 }
3310 }
3311 }
3312 }
3313 });
3314 } else {
3315 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3317 quote! {
3318 let mut #field_name = String::new();
3319 for attr in owned_e.attributes() {
3320 if let Ok(attr) = attr {
3321 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3322 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3323 break;
3324 }
3325 }
3326 }
3327 }
3328 }).collect();
3329
3330 arms.push(quote! {
3331 if &tag_name == #tag {
3332 #(#attr_extractions)*
3333 actions.push(#enum_name::#variant_name {
3334 #(#field_names),*
3335 });
3336 }
3337 });
3338 }
3339 }
3340 }
3341 }
3342 }
3343
3344 quote! {
3345 #(#arms)*
3346 }
3347}
3348
3349#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3351pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3352 let input = parse_macro_input!(input as DeriveInput);
3353
3354 let found_crate =
3355 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3356 let crate_path = match found_crate {
3357 FoundCrate::Itself => {
3358 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3360 quote!(::#ident)
3361 }
3362 FoundCrate::Name(name) => {
3363 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3364 quote!(::#ident)
3365 }
3366 };
3367
3368 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
3370
3371 let struct_name = &input.ident;
3372 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3373
3374 let placeholders = parse_template_placeholders_with_mode(&template);
3376
3377 let mut converted_template = template.clone();
3379 let mut context_fields = Vec::new();
3380
3381 let fields = match &input.data {
3383 Data::Struct(data_struct) => match &data_struct.fields {
3384 syn::Fields::Named(fields) => &fields.named,
3385 _ => panic!("ToPromptFor is only supported for structs with named fields"),
3386 },
3387 _ => panic!("ToPromptFor is only supported for structs"),
3388 };
3389
3390 let has_mode_support = input.attrs.iter().any(|attr| {
3392 if attr.path().is_ident("prompt")
3393 && let Ok(metas) =
3394 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3395 {
3396 for meta in metas {
3397 if let Meta::NameValue(nv) = meta
3398 && nv.path.is_ident("mode")
3399 {
3400 return true;
3401 }
3402 }
3403 }
3404 false
3405 });
3406
3407 for (placeholder_name, mode_opt) in &placeholders {
3409 if placeholder_name == "self" {
3410 if let Some(specific_mode) = mode_opt {
3411 let unique_key = format!("self__{}", specific_mode);
3413
3414 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3416 let replacement = format!("{{{{ {} }}}}", unique_key);
3417 converted_template = converted_template.replace(&pattern, &replacement);
3418
3419 context_fields.push(quote! {
3421 context.insert(
3422 #unique_key.to_string(),
3423 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3424 );
3425 });
3426 } else {
3427 if has_mode_support {
3430 context_fields.push(quote! {
3432 context.insert(
3433 "self".to_string(),
3434 minijinja::Value::from(self.to_prompt_with_mode(mode))
3435 );
3436 });
3437 } else {
3438 context_fields.push(quote! {
3440 context.insert(
3441 "self".to_string(),
3442 minijinja::Value::from(self.to_prompt())
3443 );
3444 });
3445 }
3446 }
3447 } else {
3448 let field_exists = fields.iter().any(|f| {
3451 f.ident
3452 .as_ref()
3453 .is_some_and(|ident| ident == placeholder_name)
3454 });
3455
3456 if field_exists {
3457 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3458
3459 context_fields.push(quote! {
3463 context.insert(
3464 #placeholder_name.to_string(),
3465 minijinja::Value::from_serialize(&self.#field_ident)
3466 );
3467 });
3468 }
3469 }
3471 }
3472
3473 let expanded = quote! {
3474 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3475 where
3476 #target_type: serde::Serialize,
3477 {
3478 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3479 let mut env = minijinja::Environment::new();
3481 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3482 panic!("Failed to parse template: {}", e)
3483 });
3484
3485 let tmpl = env.get_template("prompt").unwrap();
3486
3487 let mut context = std::collections::HashMap::new();
3489 context.insert(
3491 "self".to_string(),
3492 minijinja::Value::from_serialize(self)
3493 );
3494 context.insert(
3496 "target".to_string(),
3497 minijinja::Value::from_serialize(target)
3498 );
3499 #(#context_fields)*
3500
3501 tmpl.render(context).unwrap_or_else(|e| {
3503 format!("Failed to render prompt: {}", e)
3504 })
3505 }
3506 }
3507 };
3508
3509 TokenStream::from(expanded)
3510}
3511
3512struct AgentAttrs {
3518 expertise: Option<String>,
3519 output: Option<syn::Type>,
3520 backend: Option<String>,
3521 model: Option<String>,
3522 inner: Option<String>,
3523 default_inner: Option<String>,
3524 max_retries: Option<u32>,
3525 profile: Option<String>,
3526}
3527
3528impl Parse for AgentAttrs {
3529 fn parse(input: ParseStream) -> syn::Result<Self> {
3530 let mut expertise = None;
3531 let mut output = None;
3532 let mut backend = None;
3533 let mut model = None;
3534 let mut inner = None;
3535 let mut default_inner = None;
3536 let mut max_retries = None;
3537 let mut profile = None;
3538
3539 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3540
3541 for meta in pairs {
3542 match meta {
3543 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3544 if let syn::Expr::Lit(syn::ExprLit {
3545 lit: syn::Lit::Str(lit_str),
3546 ..
3547 }) = &nv.value
3548 {
3549 expertise = Some(lit_str.value());
3550 }
3551 }
3552 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3553 if let syn::Expr::Lit(syn::ExprLit {
3554 lit: syn::Lit::Str(lit_str),
3555 ..
3556 }) = &nv.value
3557 {
3558 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3559 output = Some(ty);
3560 }
3561 }
3562 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3563 if let syn::Expr::Lit(syn::ExprLit {
3564 lit: syn::Lit::Str(lit_str),
3565 ..
3566 }) = &nv.value
3567 {
3568 backend = Some(lit_str.value());
3569 }
3570 }
3571 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3572 if let syn::Expr::Lit(syn::ExprLit {
3573 lit: syn::Lit::Str(lit_str),
3574 ..
3575 }) = &nv.value
3576 {
3577 model = Some(lit_str.value());
3578 }
3579 }
3580 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3581 if let syn::Expr::Lit(syn::ExprLit {
3582 lit: syn::Lit::Str(lit_str),
3583 ..
3584 }) = &nv.value
3585 {
3586 inner = Some(lit_str.value());
3587 }
3588 }
3589 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3590 if let syn::Expr::Lit(syn::ExprLit {
3591 lit: syn::Lit::Str(lit_str),
3592 ..
3593 }) = &nv.value
3594 {
3595 default_inner = Some(lit_str.value());
3596 }
3597 }
3598 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3599 if let syn::Expr::Lit(syn::ExprLit {
3600 lit: syn::Lit::Int(lit_int),
3601 ..
3602 }) = &nv.value
3603 {
3604 max_retries = Some(lit_int.base10_parse()?);
3605 }
3606 }
3607 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3608 if let syn::Expr::Lit(syn::ExprLit {
3609 lit: syn::Lit::Str(lit_str),
3610 ..
3611 }) = &nv.value
3612 {
3613 profile = Some(lit_str.value());
3614 }
3615 }
3616 _ => {}
3617 }
3618 }
3619
3620 Ok(AgentAttrs {
3621 expertise,
3622 output,
3623 backend,
3624 model,
3625 inner,
3626 default_inner,
3627 max_retries,
3628 profile,
3629 })
3630 }
3631}
3632
3633fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3635 for attr in attrs {
3636 if attr.path().is_ident("agent") {
3637 return attr.parse_args::<AgentAttrs>();
3638 }
3639 }
3640
3641 Ok(AgentAttrs {
3642 expertise: None,
3643 output: None,
3644 backend: None,
3645 model: None,
3646 inner: None,
3647 default_inner: None,
3648 max_retries: None,
3649 profile: None,
3650 })
3651}
3652
3653fn generate_backend_constructors(
3655 struct_name: &syn::Ident,
3656 backend: &str,
3657 _model: Option<&str>,
3658 _profile: Option<&str>,
3659 crate_path: &proc_macro2::TokenStream,
3660) -> proc_macro2::TokenStream {
3661 match backend {
3662 "claude" => {
3663 quote! {
3664 impl #struct_name {
3665 pub fn with_claude() -> Self {
3667 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3668 }
3669
3670 pub fn with_claude_model(model: &str) -> Self {
3672 Self::new(
3673 #crate_path::agent::impls::ClaudeCodeAgent::new()
3674 .with_model_str(model)
3675 )
3676 }
3677 }
3678 }
3679 }
3680 "gemini" => {
3681 quote! {
3682 impl #struct_name {
3683 pub fn with_gemini() -> Self {
3685 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3686 }
3687
3688 pub fn with_gemini_model(model: &str) -> Self {
3690 Self::new(
3691 #crate_path::agent::impls::GeminiAgent::new()
3692 .with_model_str(model)
3693 )
3694 }
3695 }
3696 }
3697 }
3698 _ => quote! {},
3699 }
3700}
3701
3702fn generate_default_impl(
3704 struct_name: &syn::Ident,
3705 backend: &str,
3706 model: Option<&str>,
3707 profile: Option<&str>,
3708 crate_path: &proc_macro2::TokenStream,
3709) -> proc_macro2::TokenStream {
3710 let profile_expr = if let Some(profile_str) = profile {
3712 match profile_str.to_lowercase().as_str() {
3713 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3714 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3715 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3716 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3718 } else {
3719 quote! { #crate_path::agent::ExecutionProfile::default() }
3720 };
3721
3722 let agent_init = match backend {
3723 "gemini" => {
3724 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3725
3726 if let Some(model_str) = model {
3727 builder = quote! { #builder.with_model_str(#model_str) };
3728 }
3729
3730 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3731 builder
3732 }
3733 _ => {
3734 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3736
3737 if let Some(model_str) = model {
3738 builder = quote! { #builder.with_model_str(#model_str) };
3739 }
3740
3741 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3742 builder
3743 }
3744 };
3745
3746 quote! {
3747 impl Default for #struct_name {
3748 fn default() -> Self {
3749 Self::new(#agent_init)
3750 }
3751 }
3752 }
3753}
3754
3755#[proc_macro_derive(Agent, attributes(agent))]
3764pub fn derive_agent(input: TokenStream) -> TokenStream {
3765 let input = parse_macro_input!(input as DeriveInput);
3766 let struct_name = &input.ident;
3767
3768 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3770 Ok(attrs) => attrs,
3771 Err(e) => return e.to_compile_error().into(),
3772 };
3773
3774 let expertise = agent_attrs
3775 .expertise
3776 .unwrap_or_else(|| String::from("general AI assistant"));
3777 let output_type = agent_attrs
3778 .output
3779 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3780 let backend = agent_attrs
3781 .backend
3782 .unwrap_or_else(|| String::from("claude"));
3783 let model = agent_attrs.model;
3784 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3789 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3790 let crate_path = match found_crate {
3791 FoundCrate::Itself => {
3792 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3794 quote!(::#ident)
3795 }
3796 FoundCrate::Name(name) => {
3797 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3798 quote!(::#ident)
3799 }
3800 };
3801
3802 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3803
3804 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3806 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3807
3808 let enhanced_expertise = if is_string_output {
3810 quote! { #expertise }
3812 } else {
3813 let type_name = quote!(#output_type).to_string();
3815 quote! {
3816 {
3817 use std::sync::OnceLock;
3818 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3819
3820 EXPERTISE_CACHE.get_or_init(|| {
3821 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3823
3824 if schema.is_empty() {
3825 format!(
3827 concat!(
3828 #expertise,
3829 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3830 "Do not include any text outside the JSON object."
3831 ),
3832 #type_name
3833 )
3834 } else {
3835 format!(
3837 concat!(
3838 #expertise,
3839 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3840 ),
3841 schema
3842 )
3843 }
3844 }).as_str()
3845 }
3846 }
3847 };
3848
3849 let agent_init = match backend.as_str() {
3851 "gemini" => {
3852 if let Some(model_str) = model {
3853 quote! {
3854 use #crate_path::agent::impls::GeminiAgent;
3855 let agent = GeminiAgent::new().with_model_str(#model_str);
3856 }
3857 } else {
3858 quote! {
3859 use #crate_path::agent::impls::GeminiAgent;
3860 let agent = GeminiAgent::new();
3861 }
3862 }
3863 }
3864 "claude" => {
3865 if let Some(model_str) = model {
3866 quote! {
3867 use #crate_path::agent::impls::ClaudeCodeAgent;
3868 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3869 }
3870 } else {
3871 quote! {
3872 use #crate_path::agent::impls::ClaudeCodeAgent;
3873 let agent = ClaudeCodeAgent::new();
3874 }
3875 }
3876 }
3877 _ => {
3878 if let Some(model_str) = model {
3880 quote! {
3881 use #crate_path::agent::impls::ClaudeCodeAgent;
3882 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3883 }
3884 } else {
3885 quote! {
3886 use #crate_path::agent::impls::ClaudeCodeAgent;
3887 let agent = ClaudeCodeAgent::new();
3888 }
3889 }
3890 }
3891 };
3892
3893 let expanded = quote! {
3894 #[async_trait::async_trait]
3895 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3896 type Output = #output_type;
3897
3898 fn expertise(&self) -> &str {
3899 #enhanced_expertise
3900 }
3901
3902 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3903 #agent_init
3905
3906 let agent_ref = &agent;
3908 #crate_path::agent::retry::retry_execution(
3909 #max_retries,
3910 &intent,
3911 move |payload| {
3912 let payload = payload.clone();
3913 async move {
3914 let response = agent_ref.execute(payload).await?;
3916
3917 let json_str = #crate_path::extract_json(&response)
3919 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3920 message: format!("Failed to extract JSON: {}", e),
3921 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3922 })?;
3923
3924 serde_json::from_str::<Self::Output>(&json_str)
3926 .map_err(|e| {
3927 let reason = if e.is_eof() {
3929 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3930 } else if e.is_syntax() {
3931 #crate_path::agent::error::ParseErrorReason::InvalidJson
3932 } else {
3933 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3934 };
3935
3936 #crate_path::agent::AgentError::ParseError {
3937 message: format!("Failed to parse JSON: {}", e),
3938 reason,
3939 }
3940 })
3941 }
3942 }
3943 ).await
3944 }
3945
3946 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3947 #agent_init
3949 agent.is_available().await
3950 }
3951 }
3952 };
3953
3954 TokenStream::from(expanded)
3955}
3956
3957#[proc_macro_attribute]
3972pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3973 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3975 Ok(attrs) => attrs,
3976 Err(e) => return e.to_compile_error().into(),
3977 };
3978
3979 let input = parse_macro_input!(item as DeriveInput);
3981 let struct_name = &input.ident;
3982 let vis = &input.vis;
3983
3984 let expertise = agent_attrs
3985 .expertise
3986 .unwrap_or_else(|| String::from("general AI assistant"));
3987 let output_type = agent_attrs
3988 .output
3989 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3990 let backend = agent_attrs
3991 .backend
3992 .unwrap_or_else(|| String::from("claude"));
3993 let model = agent_attrs.model;
3994 let profile = agent_attrs.profile;
3995
3996 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3998 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3999
4000 let found_crate =
4002 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4003 let crate_path = match found_crate {
4004 FoundCrate::Itself => {
4005 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4006 quote!(::#ident)
4007 }
4008 FoundCrate::Name(name) => {
4009 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4010 quote!(::#ident)
4011 }
4012 };
4013
4014 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
4016 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
4017
4018 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
4020 let type_path: syn::Type =
4022 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
4023 quote! { #type_path }
4024 } else {
4025 match backend.as_str() {
4027 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
4028 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
4029 }
4030 };
4031
4032 let struct_def = quote! {
4034 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
4035 inner: #inner_generic_ident,
4036 }
4037 };
4038
4039 let constructors = quote! {
4041 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
4042 pub fn new(inner: #inner_generic_ident) -> Self {
4044 Self { inner }
4045 }
4046 }
4047 };
4048
4049 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
4051 let default_impl = quote! {
4053 impl Default for #struct_name {
4054 fn default() -> Self {
4055 Self {
4056 inner: <#default_agent_type as Default>::default(),
4057 }
4058 }
4059 }
4060 };
4061 (quote! {}, default_impl)
4062 } else {
4063 let backend_constructors = generate_backend_constructors(
4065 struct_name,
4066 &backend,
4067 model.as_deref(),
4068 profile.as_deref(),
4069 &crate_path,
4070 );
4071 let default_impl = generate_default_impl(
4072 struct_name,
4073 &backend,
4074 model.as_deref(),
4075 profile.as_deref(),
4076 &crate_path,
4077 );
4078 (backend_constructors, default_impl)
4079 };
4080
4081 let enhanced_expertise = if is_string_output {
4083 quote! { #expertise }
4085 } else {
4086 let type_name = quote!(#output_type).to_string();
4088 quote! {
4089 {
4090 use std::sync::OnceLock;
4091 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4092
4093 EXPERTISE_CACHE.get_or_init(|| {
4094 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4096
4097 if schema.is_empty() {
4098 format!(
4100 concat!(
4101 #expertise,
4102 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4103 "Do not include any text outside the JSON object."
4104 ),
4105 #type_name
4106 )
4107 } else {
4108 format!(
4110 concat!(
4111 #expertise,
4112 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4113 ),
4114 schema
4115 )
4116 }
4117 }).as_str()
4118 }
4119 }
4120 };
4121
4122 let agent_impl = quote! {
4124 #[async_trait::async_trait]
4125 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4126 where
4127 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
4128 {
4129 type Output = #output_type;
4130
4131 fn expertise(&self) -> &str {
4132 #enhanced_expertise
4133 }
4134
4135 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4136 let enhanced_payload = intent.prepend_text(self.expertise());
4138
4139 let response = self.inner.execute(enhanced_payload).await?;
4141
4142 let json_str = #crate_path::extract_json(&response)
4144 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4145 message: e.to_string(),
4146 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4147 })?;
4148
4149 serde_json::from_str(&json_str).map_err(|e| {
4151 let reason = if e.is_eof() {
4152 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4153 } else if e.is_syntax() {
4154 #crate_path::agent::error::ParseErrorReason::InvalidJson
4155 } else {
4156 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4157 };
4158 #crate_path::agent::AgentError::ParseError {
4159 message: e.to_string(),
4160 reason,
4161 }
4162 })
4163 }
4164
4165 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4166 self.inner.is_available().await
4167 }
4168 }
4169 };
4170
4171 let expanded = quote! {
4172 #struct_def
4173 #constructors
4174 #backend_constructors
4175 #default_impl
4176 #agent_impl
4177 };
4178
4179 TokenStream::from(expanded)
4180}
4181
4182#[proc_macro_derive(TypeMarker)]
4204pub fn derive_type_marker(input: TokenStream) -> TokenStream {
4205 let input = parse_macro_input!(input as DeriveInput);
4206 let struct_name = &input.ident;
4207 let type_name_str = struct_name.to_string();
4208
4209 let found_crate =
4211 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4212 let crate_path = match found_crate {
4213 FoundCrate::Itself => {
4214 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4215 quote!(::#ident)
4216 }
4217 FoundCrate::Name(name) => {
4218 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4219 quote!(::#ident)
4220 }
4221 };
4222
4223 let expanded = quote! {
4224 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4225 const TYPE_NAME: &'static str = #type_name_str;
4226 }
4227 };
4228
4229 TokenStream::from(expanded)
4230}
4231
4232#[proc_macro_attribute]
4268pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
4269 let input = parse_macro_input!(item as syn::DeriveInput);
4270 let struct_name = &input.ident;
4271 let vis = &input.vis;
4272 let type_name_str = struct_name.to_string();
4273
4274 let default_fn_name = syn::Ident::new(
4276 &format!("default_{}_type", to_snake_case(&type_name_str)),
4277 struct_name.span(),
4278 );
4279
4280 let found_crate =
4282 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4283 let crate_path = match found_crate {
4284 FoundCrate::Itself => {
4285 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4286 quote!(::#ident)
4287 }
4288 FoundCrate::Name(name) => {
4289 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4290 quote!(::#ident)
4291 }
4292 };
4293
4294 let fields = match &input.data {
4296 syn::Data::Struct(data_struct) => match &data_struct.fields {
4297 syn::Fields::Named(fields) => &fields.named,
4298 _ => {
4299 return syn::Error::new_spanned(
4300 struct_name,
4301 "type_marker only works with structs with named fields",
4302 )
4303 .to_compile_error()
4304 .into();
4305 }
4306 },
4307 _ => {
4308 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
4309 .to_compile_error()
4310 .into();
4311 }
4312 };
4313
4314 let mut new_fields = vec![];
4316
4317 let default_fn_name_str = default_fn_name.to_string();
4319 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
4320
4321 new_fields.push(quote! {
4326 #[serde(default = #default_fn_name_lit)]
4327 __type: String
4328 });
4329
4330 for field in fields {
4332 new_fields.push(quote! { #field });
4333 }
4334
4335 let attrs = &input.attrs;
4337 let generics = &input.generics;
4338
4339 let expanded = quote! {
4340 fn #default_fn_name() -> String {
4342 #type_name_str.to_string()
4343 }
4344
4345 #(#attrs)*
4347 #vis struct #struct_name #generics {
4348 #(#new_fields),*
4349 }
4350
4351 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4353 const TYPE_NAME: &'static str = #type_name_str;
4354 }
4355 };
4356
4357 TokenStream::from(expanded)
4358}