1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
169
170 if is_vec {
171 let comment = if !field_docs.is_empty() {
174 format!(" // {}", field_docs)
175 } else {
176 String::new()
177 };
178
179 field_schema_parts.push(quote! {
180 {
181 let type_name = stringify!(#inner_type);
182 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
183 }
184 });
185
186 if let Some(inner) = inner_type
188 && !is_primitive_type(inner)
189 {
190 nested_type_collectors.push(quote! {
191 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
192 });
193 }
194 } else {
195 let field_type = &field.ty;
197 let is_primitive = is_primitive_type(field_type);
198
199 if !is_primitive {
200 let comment = if !field_docs.is_empty() {
203 format!(" // {}", field_docs)
204 } else {
205 String::new()
206 };
207
208 field_schema_parts.push(quote! {
209 {
210 let type_name = stringify!(#field_type);
211 format!(" {}: {};{}", #field_name_str, type_name, #comment)
212 }
213 });
214
215 nested_type_collectors.push(quote! {
217 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
218 });
219 } else {
220 let type_str = format_type_for_schema(&field.ty);
223 let comment = if !field_docs.is_empty() {
224 format!(" // {}", field_docs)
225 } else {
226 String::new()
227 };
228
229 field_schema_parts.push(quote! {
230 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
231 });
232 }
233 }
234 }
235
236 let mut header_lines = Vec::new();
251
252 if !struct_docs.is_empty() {
254 header_lines.push("/**".to_string());
255 header_lines.push(format!(" * {}", struct_docs));
256 header_lines.push(" */".to_string());
257 }
258
259 header_lines.push(format!("type {} = {{", struct_name));
261
262 quote! {
263 {
264 let mut all_lines: Vec<String> = Vec::new();
265
266 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
268 let mut seen_types = std::collections::HashSet::<String>::new();
269
270 for schema in nested_schemas {
271 if !schema.is_empty() {
272 if seen_types.insert(schema.clone()) {
274 all_lines.push(schema);
275 all_lines.push(String::new()); }
277 }
278 }
279
280 let mut lines: Vec<String> = Vec::new();
282 #(lines.push(#header_lines.to_string());)*
283 #(lines.push(#field_schema_parts);)*
284 lines.push("}".to_string());
285 all_lines.push(lines.join("\n"));
286
287 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
288 }
289 }
290}
291
292fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
294 if let syn::Type::Path(type_path) = ty
295 && let Some(last_segment) = type_path.path.segments.last()
296 && last_segment.ident == "Vec"
297 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
298 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
299 {
300 return (true, Some(inner_type));
301 }
302 (false, None)
303}
304
305fn is_primitive_type(ty: &syn::Type) -> bool {
307 if let syn::Type::Path(type_path) = ty
308 && let Some(last_segment) = type_path.path.segments.last()
309 {
310 let type_name = last_segment.ident.to_string();
311 matches!(
312 type_name.as_str(),
313 "String"
314 | "str"
315 | "i8"
316 | "i16"
317 | "i32"
318 | "i64"
319 | "i128"
320 | "isize"
321 | "u8"
322 | "u16"
323 | "u32"
324 | "u64"
325 | "u128"
326 | "usize"
327 | "f32"
328 | "f64"
329 | "bool"
330 | "Vec"
331 | "Option"
332 | "HashMap"
333 | "BTreeMap"
334 | "HashSet"
335 | "BTreeSet"
336 )
337 } else {
338 true
340 }
341}
342
343fn format_type_for_schema(ty: &syn::Type) -> String {
345 match ty {
347 syn::Type::Path(type_path) => {
348 let path = &type_path.path;
349 if let Some(last_segment) = path.segments.last() {
350 let type_name = last_segment.ident.to_string();
351
352 if type_name == "Option"
354 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
355 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
356 {
357 return format!("{} | null", format_type_for_schema(inner_type));
358 }
359
360 match type_name.as_str() {
362 "String" | "str" => "string".to_string(),
363 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
364 | "u64" | "u128" | "usize" => "number".to_string(),
365 "f32" | "f64" => "number".to_string(),
366 "bool" => "boolean".to_string(),
367 "Vec" => {
368 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
369 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
370 {
371 return format!("{}[]", format_type_for_schema(inner_type));
372 }
373 "array".to_string()
374 }
375 _ => type_name,
377 }
378 } else {
379 "unknown".to_string()
380 }
381 }
382 _ => "unknown".to_string(),
383 }
384}
385
/// Options parsed from `#[prompt(...)]` on an enum variant
/// (see `parse_prompt_attributes`).
#[derive(Default)]
struct PromptAttributes {
    /// `#[prompt(skip)]`: the variant is left out of the generated schema.
    skip: bool,
    /// `#[prompt(rename = "...")]`: name shown instead of the variant's own
    /// (or serde-renamed) name.
    rename: Option<String>,
    /// `#[prompt(description = "...")]` or `#[prompt("...")]`: text appended
    /// as a comment / description for the variant.
    description: Option<String>,
}
393
394fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
397 let mut result = PromptAttributes::default();
398
399 for attr in attrs {
400 if attr.path().is_ident("prompt") {
401 if let Ok(meta_list) = attr.meta.require_list() {
403 if let Ok(metas) =
405 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
406 {
407 for meta in metas {
408 if let Meta::NameValue(nv) = meta {
409 if nv.path.is_ident("rename") {
410 if let syn::Expr::Lit(syn::ExprLit {
411 lit: syn::Lit::Str(lit_str),
412 ..
413 }) = nv.value
414 {
415 result.rename = Some(lit_str.value());
416 }
417 } else if nv.path.is_ident("description")
418 && let syn::Expr::Lit(syn::ExprLit {
419 lit: syn::Lit::Str(lit_str),
420 ..
421 }) = nv.value
422 {
423 result.description = Some(lit_str.value());
424 }
425 } else if let Meta::Path(path) = meta
426 && path.is_ident("skip")
427 {
428 result.skip = true;
429 }
430 }
431 }
432
433 let tokens_str = meta_list.tokens.to_string();
435 if tokens_str == "skip" {
436 result.skip = true;
437 }
438 }
439
440 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
442 result.description = Some(lit_str.value());
443 }
444 }
445 }
446 result
447}
448
/// Produces a placeholder example value (as source text) for a schema type
/// name returned by `format_type_for_schema`.
///
/// Any type ending in `[]` gives `[]`. A union (`A | B`) uses the example for
/// its first member. `string`, `number` and `boolean` give `"example"`, `0` and
/// `false`. Anything else gives `null`.
fn generate_example_value_for_type(type_str: &str) -> String {
    // Arrays win over unions, so `T | null[]` still maps to `[]`.
    if type_str.ends_with("[]") {
        return "[]".to_string();
    }

    if let Some((first_member, _)) = type_str.split_once('|') {
        return generate_example_value_for_type(first_member.trim());
    }

    let example = match type_str {
        "string" => "\"example\"",
        "number" => "0",
        "boolean" => "false",
        _ => "null",
    };
    String::from(example)
}
464
465fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
467 for attr in attrs {
468 if attr.path().is_ident("serde")
469 && let Ok(meta_list) = attr.meta.require_list()
470 && let Ok(metas) =
471 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
472 {
473 for meta in metas {
474 if let Meta::NameValue(nv) = meta
475 && nv.path.is_ident("rename")
476 && let syn::Expr::Lit(syn::ExprLit {
477 lit: syn::Lit::Str(lit_str),
478 ..
479 }) = nv.value
480 {
481 return Some(lit_str.value());
482 }
483 }
484 }
485 }
486 None
487}
488
/// A serde `rename_all` rule, applied to enum variant names so the generated
/// prompt shows the same strings serde (de)serializes.
///
/// Variant names are expected to be `PascalCase` (the Rust convention). Every
/// rule transforms from that form the way serde does for variants, using
/// ASCII-only case conversion. Non-ASCII characters therefore pass through
/// unchanged instead of being truncated to the first char of a multi-char
/// case mapping (e.g. `ß` → `SS`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leave the name untouched. Never produced by `from_str`, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    None,
    /// `"lowercase"`: `HttpServer` → `httpserver`
    LowerCase,
    /// `"UPPERCASE"`: `HttpServer` → `HTTPSERVER`
    UpperCase,
    /// `"PascalCase"`: unchanged, since variants already are PascalCase
    PascalCase,
    /// `"camelCase"`: `HttpServer` → `httpServer`
    CamelCase,
    /// `"snake_case"`: `HttpServer` → `http_server`
    SnakeCase,
    /// `"SCREAMING_SNAKE_CASE"`: `HttpServer` → `HTTP_SERVER`
    ScreamingSnakeCase,
    /// `"kebab-case"`: `HttpServer` → `http-server`
    KebabCase,
    /// `"SCREAMING-KEBAB-CASE"`: `HttpServer` → `HTTP-SERVER`
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses the rule name used in `#[serde(rename_all = "...")]`.
    /// Returns `None` for unrecognized names.
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies the rule to a (PascalCase) variant name.
    fn apply(&self, name: &str) -> String {
        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_ascii_lowercase(),
            Self::UpperCase => name.to_ascii_uppercase(),
            Self::CamelCase => {
                let mut chars = name.chars();
                match chars.next() {
                    Some(first) => {
                        let mut out = String::with_capacity(name.len());
                        out.push(first.to_ascii_lowercase());
                        out.push_str(chars.as_str());
                        out
                    }
                    None => String::new(),
                }
            }
            Self::SnakeCase => Self::separate_words(name, '_').to_ascii_lowercase(),
            Self::ScreamingSnakeCase => Self::separate_words(name, '_').to_ascii_uppercase(),
            Self::KebabCase => Self::separate_words(name, '-').to_ascii_lowercase(),
            Self::ScreamingKebabCase => Self::separate_words(name, '-').to_ascii_uppercase(),
        }
    }

    /// Inserts `separator` before every uppercase character except one at the
    /// very start, leaving the case of every character unchanged.
    fn separate_words(name: &str, separator: char) -> String {
        let mut out = String::with_capacity(name.len() + name.len() / 2);
        for (index, ch) in name.chars().enumerate() {
            if index > 0 && ch.is_uppercase() {
                out.push(separator);
            }
            out.push(ch);
        }
        out
    }
}
582
583fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
585 for attr in attrs {
586 if attr.path().is_ident("serde")
587 && let Ok(meta_list) = attr.meta.require_list()
588 {
589 if let Ok(metas) =
591 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
592 {
593 for meta in metas {
594 if let Meta::NameValue(nv) = meta
595 && nv.path.is_ident("rename_all")
596 && let syn::Expr::Lit(syn::ExprLit {
597 lit: syn::Lit::Str(lit_str),
598 ..
599 }) = nv.value
600 {
601 return RenameRule::from_str(&lit_str.value());
602 }
603 }
604 }
605 }
606 }
607 None
608}
609
610fn parse_serde_tag(attrs: &[syn::Attribute]) -> Option<String> {
613 for attr in attrs {
614 if attr.path().is_ident("serde")
615 && let Ok(meta_list) = attr.meta.require_list()
616 {
617 if let Ok(metas) =
619 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
620 {
621 for meta in metas {
622 if let Meta::NameValue(nv) = meta
623 && nv.path.is_ident("tag")
624 && let syn::Expr::Lit(syn::ExprLit {
625 lit: syn::Lit::Str(lit_str),
626 ..
627 }) = nv.value
628 {
629 return Some(lit_str.value());
630 }
631 }
632 }
633 }
634 }
635 None
636}
637
638fn parse_serde_untagged(attrs: &[syn::Attribute]) -> bool {
641 for attr in attrs {
642 if attr.path().is_ident("serde")
643 && let Ok(meta_list) = attr.meta.require_list()
644 {
645 if let Ok(metas) =
647 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
648 {
649 for meta in metas {
650 if let Meta::Path(path) = meta
651 && path.is_ident("untagged")
652 {
653 return true;
654 }
655 }
656 }
657 }
658 }
659 false
660}
661
/// Options parsed from `#[prompt(...)]` on a struct field
/// (see `parse_field_prompt_attrs`).
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: the field is omitted from the example and schema renderings.
    skip: bool,
    /// `rename = "..."`: alternate display name for the field.
    rename: Option<String>,
    /// `format_with = "..."`: presumably the path of a custom formatting
    /// function. TODO confirm against its use site.
    format_with: Option<String>,
    /// `image`: flag marking the field as image content. Exact handling lives
    /// outside these helpers; TODO confirm.
    image: bool,
    /// `example = "..."`: value used, as a JSON string, in the example-only
    /// rendering.
    example: Option<String>,
}
671
672fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
674 let mut result = FieldPromptAttrs::default();
675
676 for attr in attrs {
677 if attr.path().is_ident("prompt") {
678 if let Ok(meta_list) = attr.meta.require_list() {
680 if let Ok(metas) =
682 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
683 {
684 for meta in metas {
685 match meta {
686 Meta::Path(path) if path.is_ident("skip") => {
687 result.skip = true;
688 }
689 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
690 if let syn::Expr::Lit(syn::ExprLit {
691 lit: syn::Lit::Str(lit_str),
692 ..
693 }) = nv.value
694 {
695 result.rename = Some(lit_str.value());
696 }
697 }
698 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
699 if let syn::Expr::Lit(syn::ExprLit {
700 lit: syn::Lit::Str(lit_str),
701 ..
702 }) = nv.value
703 {
704 result.format_with = Some(lit_str.value());
705 }
706 }
707 Meta::Path(path) if path.is_ident("image") => {
708 result.image = true;
709 }
710 Meta::NameValue(nv) if nv.path.is_ident("example") => {
711 if let syn::Expr::Lit(syn::ExprLit {
712 lit: syn::Lit::Str(lit_str),
713 ..
714 }) = nv.value
715 {
716 result.example = Some(lit_str.value());
717 }
718 }
719 _ => {}
720 }
721 }
722 } else if meta_list.tokens.to_string() == "skip" {
723 result.skip = true;
725 } else if meta_list.tokens.to_string() == "image" {
726 result.image = true;
728 }
729 }
730 }
731 }
732
733 result
734}
735
736#[proc_macro_derive(ToPrompt, attributes(prompt))]
779pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
780 let input = parse_macro_input!(input as DeriveInput);
781
782 let found_crate =
783 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
784 let crate_path = match found_crate {
785 FoundCrate::Itself => {
786 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
788 quote!(::#ident)
789 }
790 FoundCrate::Name(name) => {
791 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
792 quote!(::#ident)
793 }
794 };
795
796 match &input.data {
798 Data::Enum(data_enum) => {
799 let enum_name = &input.ident;
801 let enum_docs = extract_doc_comments(&input.attrs);
802
803 let serde_tag = parse_serde_tag(&input.attrs);
805 let is_internally_tagged = serde_tag.is_some();
806 let is_untagged = parse_serde_untagged(&input.attrs);
807
808 let rename_rule = parse_serde_rename_all(&input.attrs);
810
811 let mut variant_lines = Vec::new();
824 let mut first_variant_name = None;
825
826 let mut example_unit: Option<String> = None;
828 let mut example_struct: Option<String> = None;
829 let mut example_tuple: Option<String> = None;
830
831 let mut nested_types: Vec<&syn::Type> = Vec::new();
833
834 for variant in &data_enum.variants {
835 let variant_name = &variant.ident;
836 let variant_name_str = variant_name.to_string();
837
838 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
840
841 if prompt_attrs.skip {
843 continue;
844 }
845
846 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
852 prompt_rename.clone()
853 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
854 serde_rename
855 } else if let Some(rule) = rename_rule {
856 rule.apply(&variant_name_str)
857 } else {
858 variant_name_str.clone()
859 };
860
861 let variant_line = match &variant.fields {
863 syn::Fields::Unit => {
864 if example_unit.is_none() {
866 example_unit = Some(format!("\"{}\"", variant_value));
867 }
868
869 if let Some(desc) = &prompt_attrs.description {
871 format!(" | \"{}\" // {}", variant_value, desc)
872 } else {
873 let docs = extract_doc_comments(&variant.attrs);
874 if !docs.is_empty() {
875 format!(" | \"{}\" // {}", variant_value, docs)
876 } else {
877 format!(" | \"{}\"", variant_value)
878 }
879 }
880 }
881 syn::Fields::Named(fields) => {
882 let mut field_parts = Vec::new();
883 let mut example_field_parts = Vec::new();
884
885 if is_internally_tagged && let Some(tag_name) = &serde_tag {
887 field_parts.push(format!("{}: \"{}\"", tag_name, variant_value));
888 example_field_parts
889 .push(format!("{}: \"{}\"", tag_name, variant_value));
890 }
891
892 for field in &fields.named {
893 let field_name = field.ident.as_ref().unwrap().to_string();
894 let field_type = format_type_for_schema(&field.ty);
895 field_parts.push(format!("{}: {}", field_name, field_type.clone()));
896
897 if !is_primitive_type(&field.ty) {
899 nested_types.push(&field.ty);
900 }
901
902 let example_value = generate_example_value_for_type(&field_type);
904 example_field_parts.push(format!("{}: {}", field_name, example_value));
905 }
906
907 let field_str = field_parts.join(", ");
908 let example_field_str = example_field_parts.join(", ");
909
910 if example_struct.is_none() {
912 if is_untagged || is_internally_tagged {
913 example_struct = Some(format!("{{ {} }}", example_field_str));
914 } else {
915 example_struct = Some(format!(
916 "{{ \"{}\": {{ {} }} }}",
917 variant_value, example_field_str
918 ));
919 }
920 }
921
922 let comment = if let Some(desc) = &prompt_attrs.description {
923 format!(" // {}", desc)
924 } else {
925 let docs = extract_doc_comments(&variant.attrs);
926 if !docs.is_empty() {
927 format!(" // {}", docs)
928 } else if is_untagged {
929 format!(" // {}", variant_value)
931 } else {
932 String::new()
933 }
934 };
935
936 if is_untagged {
937 format!(" | {{ {} }}{}", field_str, comment)
939 } else if is_internally_tagged {
940 format!(" | {{ {} }}{}", field_str, comment)
942 } else {
943 format!(
945 " | {{ \"{}\": {{ {} }} }}{}",
946 variant_value, field_str, comment
947 )
948 }
949 }
950 syn::Fields::Unnamed(fields) => {
951 let field_types: Vec<String> = fields
952 .unnamed
953 .iter()
954 .map(|f| {
955 if !is_primitive_type(&f.ty) {
957 nested_types.push(&f.ty);
958 }
959 format_type_for_schema(&f.ty)
960 })
961 .collect();
962
963 let tuple_str = field_types.join(", ");
964
965 let example_values: Vec<String> = field_types
967 .iter()
968 .map(|type_str| generate_example_value_for_type(type_str))
969 .collect();
970 let example_tuple_str = example_values.join(", ");
971
972 if example_tuple.is_none() {
974 if is_untagged || is_internally_tagged {
975 example_tuple = Some(format!("[{}]", example_tuple_str));
976 } else {
977 example_tuple = Some(format!(
978 "{{ \"{}\": [{}] }}",
979 variant_value, example_tuple_str
980 ));
981 }
982 }
983
984 let comment = if let Some(desc) = &prompt_attrs.description {
985 format!(" // {}", desc)
986 } else {
987 let docs = extract_doc_comments(&variant.attrs);
988 if !docs.is_empty() {
989 format!(" // {}", docs)
990 } else if is_untagged {
991 format!(" // {}", variant_value)
993 } else {
994 String::new()
995 }
996 };
997
998 if is_untagged || is_internally_tagged {
999 format!(" | [{}]{}", tuple_str, comment)
1002 } else {
1003 format!(
1005 " | {{ \"{}\": [{}] }}{}",
1006 variant_value, tuple_str, comment
1007 )
1008 }
1009 }
1010 };
1011
1012 variant_lines.push(variant_line);
1013
1014 if first_variant_name.is_none() {
1015 first_variant_name = Some(variant_value);
1016 }
1017 }
1018
1019 let mut lines = Vec::new();
1021
1022 if !enum_docs.is_empty() {
1024 lines.push("/**".to_string());
1025 lines.push(format!(" * {}", enum_docs));
1026 lines.push(" */".to_string());
1027 }
1028
1029 lines.push(format!("type {} =", enum_name));
1031
1032 for line in &variant_lines {
1034 lines.push(line.clone());
1035 }
1036
1037 if let Some(last) = lines.last_mut()
1039 && !last.ends_with(';')
1040 {
1041 last.push(';');
1042 }
1043
1044 let mut examples = Vec::new();
1046 if let Some(ex) = example_unit {
1047 examples.push(ex);
1048 }
1049 if let Some(ex) = example_struct {
1050 examples.push(ex);
1051 }
1052 if let Some(ex) = example_tuple {
1053 examples.push(ex);
1054 }
1055
1056 if !examples.is_empty() {
1057 lines.push("".to_string()); if examples.len() == 1 {
1059 lines.push(format!("Example value: {}", examples[0]));
1060 } else {
1061 lines.push("Example values:".to_string());
1062 for ex in examples {
1063 lines.push(format!(" {}", ex));
1064 }
1065 }
1066 }
1067
1068 let nested_type_tokens: Vec<_> = nested_types
1070 .iter()
1071 .map(|field_ty| {
1072 quote! {
1073 {
1074 let type_schema = <#field_ty as #crate_path::prompt::ToPrompt>::prompt_schema();
1075 if !type_schema.is_empty() {
1076 format!("\n\n{}", type_schema)
1077 } else {
1078 String::new()
1079 }
1080 }
1081 }
1082 })
1083 .collect();
1084
1085 let prompt_string = if nested_type_tokens.is_empty() {
1086 let lines_str = lines.join("\n");
1087 quote! { #lines_str.to_string() }
1088 } else {
1089 let lines_str = lines.join("\n");
1090 quote! {
1091 {
1092 let mut result = String::from(#lines_str);
1093
1094 let nested_schemas: Vec<String> = vec![#(#nested_type_tokens),*];
1096 let mut seen_schemas = std::collections::HashSet::<String>::new();
1097
1098 for schema in nested_schemas {
1099 if !schema.is_empty() && seen_schemas.insert(schema.clone()) {
1100 result.push_str(&schema);
1101 }
1102 }
1103
1104 result
1105 }
1106 }
1107 };
1108 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1109
1110 let mut match_arms = Vec::new();
1112 for variant in &data_enum.variants {
1113 let variant_name = &variant.ident;
1114 let variant_name_str = variant_name.to_string();
1115
1116 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1118
1119 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1125 prompt_rename.clone()
1126 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1127 serde_rename
1128 } else if let Some(rule) = rename_rule {
1129 rule.apply(&variant_name_str)
1130 } else {
1131 variant_name_str.clone()
1132 };
1133
1134 match &variant.fields {
1136 syn::Fields::Unit => {
1137 if prompt_attrs.skip {
1139 match_arms.push(quote! {
1140 Self::#variant_name => stringify!(#variant_name).to_string()
1141 });
1142 } else if let Some(desc) = &prompt_attrs.description {
1143 match_arms.push(quote! {
1144 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
1145 });
1146 } else {
1147 let variant_docs = extract_doc_comments(&variant.attrs);
1148 if !variant_docs.is_empty() {
1149 match_arms.push(quote! {
1150 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
1151 });
1152 } else {
1153 match_arms.push(quote! {
1154 Self::#variant_name => #variant_value.to_string()
1155 });
1156 }
1157 }
1158 }
1159 syn::Fields::Named(fields) => {
1160 let field_bindings: Vec<_> = fields
1162 .named
1163 .iter()
1164 .map(|f| f.ident.as_ref().unwrap())
1165 .collect();
1166
1167 let field_displays: Vec<_> = fields
1168 .named
1169 .iter()
1170 .map(|f| {
1171 let field_name = f.ident.as_ref().unwrap();
1172 let field_name_str = field_name.to_string();
1173 quote! {
1174 format!("{}: {:?}", #field_name_str, #field_name)
1175 }
1176 })
1177 .collect();
1178
1179 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1180 desc.clone()
1181 } else {
1182 let docs = extract_doc_comments(&variant.attrs);
1183 if !docs.is_empty() {
1184 docs
1185 } else {
1186 String::new()
1187 }
1188 };
1189
1190 if doc_or_desc.is_empty() {
1191 match_arms.push(quote! {
1192 Self::#variant_name { #(#field_bindings),* } => {
1193 let fields = vec![#(#field_displays),*];
1194 format!("{} {{ {} }}", #variant_value, fields.join(", "))
1195 }
1196 });
1197 } else {
1198 match_arms.push(quote! {
1199 Self::#variant_name { #(#field_bindings),* } => {
1200 let fields = vec![#(#field_displays),*];
1201 format!("{}: {} {{ {} }}", #variant_value, #doc_or_desc, fields.join(", "))
1202 }
1203 });
1204 }
1205 }
1206 syn::Fields::Unnamed(fields) => {
1207 let field_count = fields.unnamed.len();
1209 let field_bindings: Vec<_> = (0..field_count)
1210 .map(|i| {
1211 syn::Ident::new(
1212 &format!("field{}", i),
1213 proc_macro2::Span::call_site(),
1214 )
1215 })
1216 .collect();
1217
1218 let field_displays: Vec<_> = field_bindings
1219 .iter()
1220 .map(|field_name| {
1221 quote! {
1222 format!("{:?}", #field_name)
1223 }
1224 })
1225 .collect();
1226
1227 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1228 desc.clone()
1229 } else {
1230 let docs = extract_doc_comments(&variant.attrs);
1231 if !docs.is_empty() {
1232 docs
1233 } else {
1234 String::new()
1235 }
1236 };
1237
1238 if doc_or_desc.is_empty() {
1239 match_arms.push(quote! {
1240 Self::#variant_name(#(#field_bindings),*) => {
1241 let fields = vec![#(#field_displays),*];
1242 format!("{}({})", #variant_value, fields.join(", "))
1243 }
1244 });
1245 } else {
1246 match_arms.push(quote! {
1247 Self::#variant_name(#(#field_bindings),*) => {
1248 let fields = vec![#(#field_displays),*];
1249 format!("{}: {}({})", #variant_value, #doc_or_desc, fields.join(", "))
1250 }
1251 });
1252 }
1253 }
1254 }
1255 }
1256
1257 let to_prompt_impl = if match_arms.is_empty() {
1258 quote! {
1260 fn to_prompt(&self) -> String {
1261 match *self {}
1262 }
1263 }
1264 } else {
1265 quote! {
1266 fn to_prompt(&self) -> String {
1267 match self {
1268 #(#match_arms),*
1269 }
1270 }
1271 }
1272 };
1273
1274 let expanded = quote! {
1275 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
1276 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1277 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
1278 }
1279
1280 #to_prompt_impl
1281
1282 fn prompt_schema() -> String {
1283 #prompt_string
1284 }
1285 }
1286 };
1287
1288 TokenStream::from(expanded)
1289 }
1290 Data::Struct(data_struct) => {
1291 let mut template_attr = None;
1293 let mut template_file_attr = None;
1294 let mut mode_attr = None;
1295 let mut validate_attr = false;
1296 let mut type_marker_attr = false;
1297
1298 for attr in &input.attrs {
1299 if attr.path().is_ident("prompt") {
1300 if let Ok(metas) =
1302 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1303 {
1304 for meta in metas {
1305 match meta {
1306 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1307 if let syn::Expr::Lit(expr_lit) = nv.value
1308 && let syn::Lit::Str(lit_str) = expr_lit.lit
1309 {
1310 template_attr = Some(lit_str.value());
1311 }
1312 }
1313 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
1314 if let syn::Expr::Lit(expr_lit) = nv.value
1315 && let syn::Lit::Str(lit_str) = expr_lit.lit
1316 {
1317 template_file_attr = Some(lit_str.value());
1318 }
1319 }
1320 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1321 if let syn::Expr::Lit(expr_lit) = nv.value
1322 && let syn::Lit::Str(lit_str) = expr_lit.lit
1323 {
1324 mode_attr = Some(lit_str.value());
1325 }
1326 }
1327 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
1328 if let syn::Expr::Lit(expr_lit) = nv.value
1329 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1330 {
1331 validate_attr = lit_bool.value();
1332 }
1333 }
1334 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
1335 if let syn::Expr::Lit(expr_lit) = nv.value
1336 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1337 {
1338 type_marker_attr = lit_bool.value();
1339 }
1340 }
1341 Meta::Path(path) if path.is_ident("type_marker") => {
1342 type_marker_attr = true;
1344 }
1345 _ => {}
1346 }
1347 }
1348 }
1349 }
1350 }
1351
1352 if template_attr.is_some() && template_file_attr.is_some() {
1354 return syn::Error::new(
1355 input.ident.span(),
1356 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
1357 ).to_compile_error().into();
1358 }
1359
1360 let template_str = if let Some(file_path) = template_file_attr {
1362 let mut full_path = None;
1366
1367 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1369 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
1371
1372 if !is_trybuild {
1373 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1375 if candidate.exists() {
1376 full_path = Some(candidate);
1377 }
1378 } else {
1379 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1385 let workspace_root = &manifest_dir[..target_pos];
1386 let original_macros_dir = std::path::Path::new(workspace_root)
1388 .join("crates")
1389 .join("llm-toolkit-macros");
1390
1391 let candidate = original_macros_dir.join(&file_path);
1392 if candidate.exists() {
1393 full_path = Some(candidate);
1394 }
1395 }
1396 }
1397 }
1398
1399 if full_path.is_none() {
1401 let candidate = std::path::Path::new(&file_path).to_path_buf();
1402 if candidate.exists() {
1403 full_path = Some(candidate);
1404 }
1405 }
1406
1407 if full_path.is_none()
1410 && let Ok(current_dir) = std::env::current_dir()
1411 {
1412 let mut search_dir = current_dir.as_path();
1413 for _ in 0..10 {
1415 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1417 if macros_dir.exists() {
1418 let candidate = macros_dir.join(&file_path);
1419 if candidate.exists() {
1420 full_path = Some(candidate);
1421 break;
1422 }
1423 }
1424 let candidate = search_dir.join(&file_path);
1426 if candidate.exists() {
1427 full_path = Some(candidate);
1428 break;
1429 }
1430 if let Some(parent) = search_dir.parent() {
1431 search_dir = parent;
1432 } else {
1433 break;
1434 }
1435 }
1436 }
1437
1438 if full_path.is_none() {
1440 let mut error_msg = format!(
1442 "Template file '{}' not found at compile time.\n\nSearched in:",
1443 file_path
1444 );
1445
1446 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1447 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1448 error_msg.push_str(&format!("\n - {}", candidate.display()));
1449 }
1450
1451 if let Ok(current_dir) = std::env::current_dir() {
1452 let candidate = current_dir.join(&file_path);
1453 error_msg.push_str(&format!("\n - {}", candidate.display()));
1454 }
1455
1456 error_msg.push_str("\n\nPlease ensure:");
1457 error_msg.push_str("\n 1. The template file exists");
1458 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1459 error_msg.push_str("\n 3. There are no typos in the path");
1460
1461 return syn::Error::new(input.ident.span(), error_msg)
1462 .to_compile_error()
1463 .into();
1464 }
1465
1466 let final_path = full_path.unwrap();
1467
1468 match std::fs::read_to_string(&final_path) {
1470 Ok(content) => Some(content),
1471 Err(e) => {
1472 return syn::Error::new(
1473 input.ident.span(),
1474 format!(
1475 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1476 file_path,
1477 e,
1478 final_path.display()
1479 ),
1480 )
1481 .to_compile_error()
1482 .into();
1483 }
1484 }
1485 } else {
1486 template_attr
1487 };
1488
1489 if validate_attr && let Some(template) = &template_str {
1491 let mut env = minijinja::Environment::new();
1493 if let Err(e) = env.add_template("validation", template) {
1494 let warning_msg =
1496 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1497 let warning_ident = syn::Ident::new(
1498 "TEMPLATE_VALIDATION_WARNING",
1499 proc_macro2::Span::call_site(),
1500 );
1501 let _warning_tokens = quote! {
1502 #[deprecated(note = #warning_msg)]
1503 const #warning_ident: () = ();
1504 let _ = #warning_ident;
1505 };
1506 eprintln!("cargo:warning={}", warning_msg);
1508 }
1509
1510 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1512 &fields.named
1513 } else {
1514 panic!("Template validation is only supported for structs with named fields.");
1515 };
1516
1517 let field_names: std::collections::HashSet<String> = fields
1518 .iter()
1519 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1520 .collect();
1521
1522 let placeholders = parse_template_placeholders_with_mode(template);
1524
1525 for (placeholder_name, _mode) in &placeholders {
1526 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1527 let warning_msg = format!(
1528 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1529 placeholder_name
1530 );
1531 eprintln!("cargo:warning={}", warning_msg);
1532 }
1533 }
1534 }
1535
1536 let name = input.ident;
1537 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1538
1539 let struct_docs = extract_doc_comments(&input.attrs);
1541
1542 let is_mode_based =
1544 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1545
1546 let expanded = if is_mode_based || mode_attr.is_some() {
1547 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1549 &fields.named
1550 } else {
1551 panic!(
1552 "Mode-based prompt generation is only supported for structs with named fields."
1553 );
1554 };
1555
1556 let struct_name_str = name.to_string();
1557
1558 let has_default = input.attrs.iter().any(|attr| {
1560 if attr.path().is_ident("derive")
1561 && let Ok(meta_list) = attr.meta.require_list()
1562 {
1563 let tokens_str = meta_list.tokens.to_string();
1564 tokens_str.contains("Default")
1565 } else {
1566 false
1567 }
1568 });
1569
1570 let schema_parts = generate_schema_only_parts(
1581 &struct_name_str,
1582 &struct_docs,
1583 fields,
1584 &crate_path,
1585 type_marker_attr,
1586 );
1587
1588 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1590
1591 quote! {
1592 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1593 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1594 match mode {
1595 "schema_only" => #schema_parts,
1596 "example_only" => #example_parts,
1597 "full" | _ => {
1598 let mut parts = Vec::new();
1600
1601 let schema_parts = #schema_parts;
1603 parts.extend(schema_parts);
1604
1605 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1607 parts.push(#crate_path::prompt::PromptPart::Text(
1608 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1609 ));
1610
1611 let example_parts = #example_parts;
1613 parts.extend(example_parts);
1614
1615 parts
1616 }
1617 }
1618 }
1619
1620 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1621 self.to_prompt_parts_with_mode("full")
1622 }
1623
1624 fn to_prompt(&self) -> String {
1625 self.to_prompt_parts()
1626 .into_iter()
1627 .filter_map(|part| match part {
1628 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1629 _ => None,
1630 })
1631 .collect::<Vec<_>>()
1632 .join("\n")
1633 }
1634
1635 fn prompt_schema() -> String {
1636 use std::sync::OnceLock;
1637 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1638
1639 SCHEMA_CACHE.get_or_init(|| {
1640 let schema_parts = #schema_parts;
1641 schema_parts
1642 .into_iter()
1643 .filter_map(|part| match part {
1644 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1645 _ => None,
1646 })
1647 .collect::<Vec<_>>()
1648 .join("\n")
1649 }).clone()
1650 }
1651 }
1652 }
1653 } else if let Some(template) = template_str {
1654 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1657 &fields.named
1658 } else {
1659 panic!(
1660 "Template prompt generation is only supported for structs with named fields."
1661 );
1662 };
1663
1664 let placeholders = parse_template_placeholders_with_mode(&template);
1666 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1668 mode.is_some()
1669 && fields
1670 .iter()
1671 .any(|f| f.ident.as_ref().unwrap() == field_name)
1672 });
1673
1674 let mut image_field_parts = Vec::new();
1675 for f in fields.iter() {
1676 let field_name = f.ident.as_ref().unwrap();
1677 let attrs = parse_field_prompt_attrs(&f.attrs);
1678
1679 if attrs.image {
1680 image_field_parts.push(quote! {
1682 parts.extend(self.#field_name.to_prompt_parts());
1683 });
1684 }
1685 }
1686
1687 if has_mode_syntax {
1689 let mut context_fields = Vec::new();
1691 let mut modified_template = template.clone();
1692
1693 for (field_name, mode_opt) in &placeholders {
1695 if let Some(mode) = mode_opt {
1696 let unique_key = format!("{}__{}", field_name, mode);
1698
1699 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1701 let replacement = format!("{{{{ {} }}}}", unique_key);
1702 modified_template = modified_template.replace(&pattern, &replacement);
1703
1704 let field_ident =
1706 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1707
1708 context_fields.push(quote! {
1710 context.insert(
1711 #unique_key.to_string(),
1712 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1713 );
1714 });
1715 }
1716 }
1717
1718 for field in fields.iter() {
1720 let field_name = field.ident.as_ref().unwrap();
1721 let field_name_str = field_name.to_string();
1722
1723 let has_mode_entry = placeholders
1725 .iter()
1726 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1727
1728 if !has_mode_entry {
1729 let is_primitive = match &field.ty {
1732 syn::Type::Path(type_path) => {
1733 if let Some(segment) = type_path.path.segments.last() {
1734 let type_name = segment.ident.to_string();
1735 matches!(
1736 type_name.as_str(),
1737 "String"
1738 | "str"
1739 | "i8"
1740 | "i16"
1741 | "i32"
1742 | "i64"
1743 | "i128"
1744 | "isize"
1745 | "u8"
1746 | "u16"
1747 | "u32"
1748 | "u64"
1749 | "u128"
1750 | "usize"
1751 | "f32"
1752 | "f64"
1753 | "bool"
1754 | "char"
1755 )
1756 } else {
1757 false
1758 }
1759 }
1760 _ => false,
1761 };
1762
1763 if is_primitive {
1764 context_fields.push(quote! {
1765 context.insert(
1766 #field_name_str.to_string(),
1767 minijinja::Value::from_serialize(&self.#field_name)
1768 );
1769 });
1770 } else {
1771 context_fields.push(quote! {
1773 context.insert(
1774 #field_name_str.to_string(),
1775 minijinja::Value::from(self.#field_name.to_prompt())
1776 );
1777 });
1778 }
1779 }
1780 }
1781
1782 quote! {
1783 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1784 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1785 let mut parts = Vec::new();
1786
1787 #(#image_field_parts)*
1789
1790 let text = {
1792 let mut env = minijinja::Environment::new();
1793 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1794 panic!("Failed to parse template: {}", e)
1795 });
1796
1797 let tmpl = env.get_template("prompt").unwrap();
1798
1799 let mut context = std::collections::HashMap::new();
1800 #(#context_fields)*
1801
1802 tmpl.render(context).unwrap_or_else(|e| {
1803 format!("Failed to render prompt: {}", e)
1804 })
1805 };
1806
1807 if !text.is_empty() {
1808 parts.push(#crate_path::prompt::PromptPart::Text(text));
1809 }
1810
1811 parts
1812 }
1813
1814 fn to_prompt(&self) -> String {
1815 let mut env = minijinja::Environment::new();
1817 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1818 panic!("Failed to parse template: {}", e)
1819 });
1820
1821 let tmpl = env.get_template("prompt").unwrap();
1822
1823 let mut context = std::collections::HashMap::new();
1824 #(#context_fields)*
1825
1826 tmpl.render(context).unwrap_or_else(|e| {
1827 format!("Failed to render prompt: {}", e)
1828 })
1829 }
1830
1831 fn prompt_schema() -> String {
1832 String::new() }
1834 }
1835 }
1836 } else {
1837 quote! {
1839 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1840 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1841 let mut parts = Vec::new();
1842
1843 #(#image_field_parts)*
1845
1846 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1848 format!("Failed to render prompt: {}", e)
1849 });
1850 if !text.is_empty() {
1851 parts.push(#crate_path::prompt::PromptPart::Text(text));
1852 }
1853
1854 parts
1855 }
1856
1857 fn to_prompt(&self) -> String {
1858 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1859 format!("Failed to render prompt: {}", e)
1860 })
1861 }
1862
1863 fn prompt_schema() -> String {
1864 String::new() }
1866 }
1867 }
1868 }
1869 } else {
1870 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1873 &fields.named
1874 } else {
1875 panic!(
1876 "Default prompt generation is only supported for structs with named fields."
1877 );
1878 };
1879
1880 let mut text_field_parts = Vec::new();
1882 let mut image_field_parts = Vec::new();
1883
1884 for f in fields.iter() {
1885 let field_name = f.ident.as_ref().unwrap();
1886 let attrs = parse_field_prompt_attrs(&f.attrs);
1887
1888 if attrs.skip {
1890 continue;
1891 }
1892
1893 if attrs.image {
1894 image_field_parts.push(quote! {
1896 parts.extend(self.#field_name.to_prompt_parts());
1897 });
1898 } else {
1899 let key = if let Some(rename) = attrs.rename {
1905 rename
1906 } else {
1907 let doc_comment = extract_doc_comments(&f.attrs);
1908 if !doc_comment.is_empty() {
1909 doc_comment
1910 } else {
1911 field_name.to_string()
1912 }
1913 };
1914
1915 let value_expr = if let Some(format_with) = attrs.format_with {
1917 let func_path: syn::Path =
1919 syn::parse_str(&format_with).unwrap_or_else(|_| {
1920 panic!("Invalid function path: {}", format_with)
1921 });
1922 quote! { #func_path(&self.#field_name) }
1923 } else {
1924 quote! { self.#field_name.to_prompt() }
1925 };
1926
1927 text_field_parts.push(quote! {
1928 text_parts.push(format!("{}: {}", #key, #value_expr));
1929 });
1930 }
1931 }
1932
1933 let struct_name_str = name.to_string();
1935 let schema_parts = generate_schema_only_parts(
1936 &struct_name_str,
1937 &struct_docs,
1938 fields,
1939 &crate_path,
1940 false, );
1942
1943 quote! {
1945 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1946 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1947 let mut parts = Vec::new();
1948
1949 #(#image_field_parts)*
1951
1952 let mut text_parts = Vec::new();
1954 #(#text_field_parts)*
1955
1956 if !text_parts.is_empty() {
1957 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1958 }
1959
1960 parts
1961 }
1962
1963 fn to_prompt(&self) -> String {
1964 let mut text_parts = Vec::new();
1965 #(#text_field_parts)*
1966 text_parts.join("\n")
1967 }
1968
1969 fn prompt_schema() -> String {
1970 use std::sync::OnceLock;
1971 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1972
1973 SCHEMA_CACHE.get_or_init(|| {
1974 let schema_parts = #schema_parts;
1975 schema_parts
1976 .into_iter()
1977 .filter_map(|part| match part {
1978 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1979 _ => None,
1980 })
1981 .collect::<Vec<_>>()
1982 .join("\n")
1983 }).clone()
1984 }
1985 }
1986 }
1987 };
1988
1989 TokenStream::from(expanded)
1990 }
1991 Data::Union(_) => {
1992 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1993 }
1994 }
1995}
1996
1997#[derive(Debug, Clone)]
1999struct TargetInfo {
2000 name: String,
2001 template: Option<String>,
2002 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
2003}
2004
/// Per-target rendering options for a single struct field under `#[derive(ToPromptSet)]`.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Leave the field out of the target's output.
    skip: bool,
    /// Label printed instead of the field name.
    rename: Option<String>,
    /// Path of a function invoked as `f(&self.field)` to produce the value text.
    format_with: Option<String>,
    /// Emit the field through its `to_prompt_parts()` instead of as a text line.
    image: bool,
    /// True when the config came from `#[prompt_for(name = "...")]`, i.e. the field
    /// was explicitly opted into this particular target.
    include_only: bool,
}
2014
2015fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
2017 let mut configs = Vec::new();
2018
2019 for attr in attrs {
2020 if attr.path().is_ident("prompt_for")
2021 && let Ok(meta_list) = attr.meta.require_list()
2022 {
2023 if meta_list.tokens.to_string() == "skip" {
2025 let config = FieldTargetConfig {
2027 skip: true,
2028 ..Default::default()
2029 };
2030 configs.push(("*".to_string(), config));
2031 } else if let Ok(metas) =
2032 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2033 {
2034 let mut target_name = None;
2035 let mut config = FieldTargetConfig::default();
2036
2037 for meta in metas {
2038 match meta {
2039 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2040 if let syn::Expr::Lit(syn::ExprLit {
2041 lit: syn::Lit::Str(lit_str),
2042 ..
2043 }) = nv.value
2044 {
2045 target_name = Some(lit_str.value());
2046 }
2047 }
2048 Meta::Path(path) if path.is_ident("skip") => {
2049 config.skip = true;
2050 }
2051 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
2052 if let syn::Expr::Lit(syn::ExprLit {
2053 lit: syn::Lit::Str(lit_str),
2054 ..
2055 }) = nv.value
2056 {
2057 config.rename = Some(lit_str.value());
2058 }
2059 }
2060 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
2061 if let syn::Expr::Lit(syn::ExprLit {
2062 lit: syn::Lit::Str(lit_str),
2063 ..
2064 }) = nv.value
2065 {
2066 config.format_with = Some(lit_str.value());
2067 }
2068 }
2069 Meta::Path(path) if path.is_ident("image") => {
2070 config.image = true;
2071 }
2072 _ => {}
2073 }
2074 }
2075
2076 if let Some(name) = target_name {
2077 config.include_only = true;
2078 configs.push((name, config));
2079 }
2080 }
2081 }
2082 }
2083
2084 configs
2085}
2086
2087fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
2089 let mut targets = Vec::new();
2090
2091 for attr in attrs {
2092 if attr.path().is_ident("prompt_for")
2093 && let Ok(meta_list) = attr.meta.require_list()
2094 && let Ok(metas) =
2095 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2096 {
2097 let mut target_name = None;
2098 let mut template = None;
2099
2100 for meta in metas {
2101 match meta {
2102 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2103 if let syn::Expr::Lit(syn::ExprLit {
2104 lit: syn::Lit::Str(lit_str),
2105 ..
2106 }) = nv.value
2107 {
2108 target_name = Some(lit_str.value());
2109 }
2110 }
2111 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2112 if let syn::Expr::Lit(syn::ExprLit {
2113 lit: syn::Lit::Str(lit_str),
2114 ..
2115 }) = nv.value
2116 {
2117 template = Some(lit_str.value());
2118 }
2119 }
2120 _ => {}
2121 }
2122 }
2123
2124 if let Some(name) = target_name {
2125 targets.push(TargetInfo {
2126 name,
2127 template,
2128 field_configs: std::collections::HashMap::new(),
2129 });
2130 }
2131 }
2132 }
2133
2134 targets
2135}
2136
2137#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
2138pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
2139 let input = parse_macro_input!(input as DeriveInput);
2140
2141 let found_crate =
2142 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2143 let crate_path = match found_crate {
2144 FoundCrate::Itself => {
2145 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2147 quote!(::#ident)
2148 }
2149 FoundCrate::Name(name) => {
2150 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2151 quote!(::#ident)
2152 }
2153 };
2154
2155 let data_struct = match &input.data {
2157 Data::Struct(data) => data,
2158 _ => {
2159 return syn::Error::new(
2160 input.ident.span(),
2161 "`#[derive(ToPromptSet)]` is only supported for structs",
2162 )
2163 .to_compile_error()
2164 .into();
2165 }
2166 };
2167
2168 let fields = match &data_struct.fields {
2169 syn::Fields::Named(fields) => &fields.named,
2170 _ => {
2171 return syn::Error::new(
2172 input.ident.span(),
2173 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
2174 )
2175 .to_compile_error()
2176 .into();
2177 }
2178 };
2179
2180 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
2182
2183 for field in fields.iter() {
2185 let field_name = field.ident.as_ref().unwrap().to_string();
2186 let field_configs = parse_prompt_for_attrs(&field.attrs);
2187
2188 for (target_name, config) in field_configs {
2189 if target_name == "*" {
2190 for target in &mut targets {
2192 target
2193 .field_configs
2194 .entry(field_name.clone())
2195 .or_insert_with(FieldTargetConfig::default)
2196 .skip = config.skip;
2197 }
2198 } else {
2199 let target_exists = targets.iter().any(|t| t.name == target_name);
2201 if !target_exists {
2202 targets.push(TargetInfo {
2204 name: target_name.clone(),
2205 template: None,
2206 field_configs: std::collections::HashMap::new(),
2207 });
2208 }
2209
2210 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
2211
2212 target.field_configs.insert(field_name.clone(), config);
2213 }
2214 }
2215 }
2216
2217 let mut match_arms = Vec::new();
2219
2220 for target in &targets {
2221 let target_name = &target.name;
2222
2223 if let Some(template_str) = &target.template {
2224 let mut image_parts = Vec::new();
2226
2227 for field in fields.iter() {
2228 let field_name = field.ident.as_ref().unwrap();
2229 let field_name_str = field_name.to_string();
2230
2231 if let Some(config) = target.field_configs.get(&field_name_str)
2232 && config.image
2233 {
2234 image_parts.push(quote! {
2235 parts.extend(self.#field_name.to_prompt_parts());
2236 });
2237 }
2238 }
2239
2240 match_arms.push(quote! {
2241 #target_name => {
2242 let mut parts = Vec::new();
2243
2244 #(#image_parts)*
2245
2246 let text = #crate_path::prompt::render_prompt(#template_str, self)
2247 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
2248 target: #target_name.to_string(),
2249 source: e,
2250 })?;
2251
2252 if !text.is_empty() {
2253 parts.push(#crate_path::prompt::PromptPart::Text(text));
2254 }
2255
2256 Ok(parts)
2257 }
2258 });
2259 } else {
2260 let mut text_field_parts = Vec::new();
2262 let mut image_field_parts = Vec::new();
2263
2264 for field in fields.iter() {
2265 let field_name = field.ident.as_ref().unwrap();
2266 let field_name_str = field_name.to_string();
2267
2268 let config = target.field_configs.get(&field_name_str);
2270
2271 if let Some(cfg) = config
2273 && cfg.skip
2274 {
2275 continue;
2276 }
2277
2278 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
2282 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
2283 .iter()
2284 .any(|(name, _)| name != "*");
2285
2286 if has_any_target_specific_config && !is_explicitly_for_this_target {
2287 continue;
2288 }
2289
2290 if let Some(cfg) = config {
2291 if cfg.image {
2292 image_field_parts.push(quote! {
2293 parts.extend(self.#field_name.to_prompt_parts());
2294 });
2295 } else {
2296 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
2297
2298 let value_expr = if let Some(format_with) = &cfg.format_with {
2299 match syn::parse_str::<syn::Path>(format_with) {
2301 Ok(func_path) => quote! { #func_path(&self.#field_name) },
2302 Err(_) => {
2303 let error_msg = format!(
2305 "Invalid function path in format_with: '{}'",
2306 format_with
2307 );
2308 quote! {
2309 compile_error!(#error_msg);
2310 String::new()
2311 }
2312 }
2313 }
2314 } else {
2315 quote! { self.#field_name.to_prompt() }
2316 };
2317
2318 text_field_parts.push(quote! {
2319 text_parts.push(format!("{}: {}", #key, #value_expr));
2320 });
2321 }
2322 } else {
2323 text_field_parts.push(quote! {
2325 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
2326 });
2327 }
2328 }
2329
2330 match_arms.push(quote! {
2331 #target_name => {
2332 let mut parts = Vec::new();
2333
2334 #(#image_field_parts)*
2335
2336 let mut text_parts = Vec::new();
2337 #(#text_field_parts)*
2338
2339 if !text_parts.is_empty() {
2340 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2341 }
2342
2343 Ok(parts)
2344 }
2345 });
2346 }
2347 }
2348
2349 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
2351
2352 match_arms.push(quote! {
2354 _ => {
2355 let available = vec![#(#target_names.to_string()),*];
2356 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
2357 target: target.to_string(),
2358 available,
2359 })
2360 }
2361 });
2362
2363 let struct_name = &input.ident;
2364 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2365
2366 let expanded = quote! {
2367 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
2368 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
2369 match target {
2370 #(#match_arms)*
2371 }
2372 }
2373 }
2374 };
2375
2376 TokenStream::from(expanded)
2377}
2378
2379struct TypeList {
2381 types: Punctuated<syn::Type, Token![,]>,
2382}
2383
2384impl Parse for TypeList {
2385 fn parse(input: ParseStream) -> syn::Result<Self> {
2386 Ok(TypeList {
2387 types: Punctuated::parse_terminated(input)?,
2388 })
2389 }
2390}
2391
2392#[proc_macro]
2416pub fn examples_section(input: TokenStream) -> TokenStream {
2417 let input = parse_macro_input!(input as TypeList);
2418
2419 let found_crate =
2420 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2421 let _crate_path = match found_crate {
2422 FoundCrate::Itself => quote!(crate),
2423 FoundCrate::Name(name) => {
2424 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2425 quote!(::#ident)
2426 }
2427 };
2428
2429 let mut type_sections = Vec::new();
2431
2432 for ty in input.types.iter() {
2433 let type_name_str = quote!(#ty).to_string();
2435
2436 type_sections.push(quote! {
2438 {
2439 let type_name = #type_name_str;
2440 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
2441 format!("---\n#### `{}`\n{}", type_name, json_example)
2442 }
2443 });
2444 }
2445
2446 let expanded = quote! {
2448 {
2449 let mut sections = Vec::new();
2450 sections.push("---".to_string());
2451 sections.push("### Examples".to_string());
2452 sections.push("".to_string());
2453 sections.push("Here are examples of the data structures you should use.".to_string());
2454 sections.push("".to_string());
2455
2456 #(sections.push(#type_sections);)*
2457
2458 sections.push("---".to_string());
2459
2460 sections.join("\n")
2461 }
2462 };
2463
2464 TokenStream::from(expanded)
2465}
2466
2467fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
2469 for attr in attrs {
2470 if attr.path().is_ident("prompt_for")
2471 && let Ok(meta_list) = attr.meta.require_list()
2472 && let Ok(metas) =
2473 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2474 {
2475 let mut target_type = None;
2476 let mut template = None;
2477
2478 for meta in metas {
2479 match meta {
2480 Meta::NameValue(nv) if nv.path.is_ident("target") => {
2481 if let syn::Expr::Lit(syn::ExprLit {
2482 lit: syn::Lit::Str(lit_str),
2483 ..
2484 }) = nv.value
2485 {
2486 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2488 }
2489 }
2490 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2491 if let syn::Expr::Lit(syn::ExprLit {
2492 lit: syn::Lit::Str(lit_str),
2493 ..
2494 }) = nv.value
2495 {
2496 template = Some(lit_str.value());
2497 }
2498 }
2499 _ => {}
2500 }
2501 }
2502
2503 if let (Some(target), Some(tmpl)) = (target_type, template) {
2504 return (target, tmpl);
2505 }
2506 }
2507 }
2508
2509 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2510}
2511
2512#[proc_macro_attribute]
2546pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2547 let input = parse_macro_input!(item as DeriveInput);
2548
2549 let found_crate =
2550 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2551 let crate_path = match found_crate {
2552 FoundCrate::Itself => {
2553 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2555 quote!(::#ident)
2556 }
2557 FoundCrate::Name(name) => {
2558 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2559 quote!(::#ident)
2560 }
2561 };
2562
2563 let enum_data = match &input.data {
2565 Data::Enum(data) => data,
2566 _ => {
2567 return syn::Error::new(
2568 input.ident.span(),
2569 "`#[define_intent]` can only be applied to enums",
2570 )
2571 .to_compile_error()
2572 .into();
2573 }
2574 };
2575
2576 let mut prompt_template = None;
2578 let mut extractor_tag = None;
2579 let mut mode = None;
2580
2581 for attr in &input.attrs {
2582 if attr.path().is_ident("intent")
2583 && let Ok(metas) =
2584 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2585 {
2586 for meta in metas {
2587 match meta {
2588 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2589 if let syn::Expr::Lit(syn::ExprLit {
2590 lit: syn::Lit::Str(lit_str),
2591 ..
2592 }) = nv.value
2593 {
2594 prompt_template = Some(lit_str.value());
2595 }
2596 }
2597 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2598 if let syn::Expr::Lit(syn::ExprLit {
2599 lit: syn::Lit::Str(lit_str),
2600 ..
2601 }) = nv.value
2602 {
2603 extractor_tag = Some(lit_str.value());
2604 }
2605 }
2606 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2607 if let syn::Expr::Lit(syn::ExprLit {
2608 lit: syn::Lit::Str(lit_str),
2609 ..
2610 }) = nv.value
2611 {
2612 mode = Some(lit_str.value());
2613 }
2614 }
2615 _ => {}
2616 }
2617 }
2618 }
2619 }
2620
2621 let mode = mode.unwrap_or_else(|| "single".to_string());
2623
2624 if mode != "single" && mode != "multi_tag" {
2626 return syn::Error::new(
2627 input.ident.span(),
2628 "`mode` must be either \"single\" or \"multi_tag\"",
2629 )
2630 .to_compile_error()
2631 .into();
2632 }
2633
2634 let prompt_template = match prompt_template {
2636 Some(p) => p,
2637 None => {
2638 return syn::Error::new(
2639 input.ident.span(),
2640 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2641 )
2642 .to_compile_error()
2643 .into();
2644 }
2645 };
2646
2647 if mode == "multi_tag" {
2649 let enum_name = &input.ident;
2650 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2651 return generate_multi_tag_output(
2652 &input,
2653 enum_name,
2654 enum_data,
2655 prompt_template,
2656 actions_doc,
2657 );
2658 }
2659
2660 let extractor_tag = match extractor_tag {
2662 Some(t) => t,
2663 None => {
2664 return syn::Error::new(
2665 input.ident.span(),
2666 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2667 )
2668 .to_compile_error()
2669 .into();
2670 }
2671 };
2672
2673 let enum_name = &input.ident;
2675 let enum_docs = extract_doc_comments(&input.attrs);
2676
2677 let mut intents_doc_lines = Vec::new();
2678
2679 if !enum_docs.is_empty() {
2681 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2682 } else {
2683 intents_doc_lines.push(format!("{}:", enum_name));
2684 }
2685 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2687
2688 for variant in &enum_data.variants {
2690 let variant_name = &variant.ident;
2691 let variant_docs = extract_doc_comments(&variant.attrs);
2692
2693 if !variant_docs.is_empty() {
2694 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2695 } else {
2696 intents_doc_lines.push(format!("- {}", variant_name));
2697 }
2698 }
2699
2700 let intents_doc_str = intents_doc_lines.join("\n");
2701
2702 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2704 let user_variables: Vec<String> = placeholders
2705 .iter()
2706 .filter_map(|(name, _)| {
2707 if name != "intents_doc" {
2708 Some(name.clone())
2709 } else {
2710 None
2711 }
2712 })
2713 .collect();
2714
2715 let enum_name_str = enum_name.to_string();
2717 let snake_case_name = to_snake_case(&enum_name_str);
2718 let function_name = syn::Ident::new(
2719 &format!("build_{}_prompt", snake_case_name),
2720 proc_macro2::Span::call_site(),
2721 );
2722
2723 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2725 .iter()
2726 .map(|var| {
2727 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2728 quote! { #ident: &str }
2729 })
2730 .collect();
2731
2732 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2734 .iter()
2735 .map(|var| {
2736 let var_str = var.clone();
2737 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2738 quote! {
2739 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2740 }
2741 })
2742 .collect();
2743
2744 let converted_template = prompt_template.clone();
2746
2747 let extractor_name = syn::Ident::new(
2749 &format!("{}Extractor", enum_name),
2750 proc_macro2::Span::call_site(),
2751 );
2752
2753 let filtered_attrs: Vec<_> = input
2755 .attrs
2756 .iter()
2757 .filter(|attr| !attr.path().is_ident("intent"))
2758 .collect();
2759
2760 let vis = &input.vis;
2762 let generics = &input.generics;
2763 let variants = &enum_data.variants;
2764 let enum_output = quote! {
2765 #(#filtered_attrs)*
2766 #vis enum #enum_name #generics {
2767 #variants
2768 }
2769 };
2770
2771 let expanded = quote! {
2773 #enum_output
2775
2776 pub fn #function_name(#(#function_params),*) -> String {
2778 let mut env = minijinja::Environment::new();
2779 env.add_template("prompt", #converted_template)
2780 .expect("Failed to parse intent prompt template");
2781
2782 let tmpl = env.get_template("prompt").unwrap();
2783
2784 let mut __template_context = std::collections::HashMap::new();
2785
2786 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2788
2789 #(#context_insertions)*
2791
2792 tmpl.render(&__template_context)
2793 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2794 }
2795
2796 pub struct #extractor_name;
2798
2799 impl #extractor_name {
2800 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2801 }
2802
2803 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2804 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2805 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2807 }
2808 }
2809 };
2810
2811 TokenStream::from(expanded)
2812}
2813
/// Converts a `CamelCase` identifier into `snake_case`.
///
/// An underscore is inserted before an uppercase character only when it is not
/// the first character and the preceding character was not uppercase, so runs of
/// capitals collapse (`"HTTPServer"` becomes `"httpserver"`). Each uppercase
/// character is replaced by the first character of its lowercase mapping; all
/// other characters (digits, `_`, lowercase) are copied verbatim.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::new();
    let mut after_upper = false;

    for (position, c) in s.chars().enumerate() {
        let upper = c.is_uppercase();
        if upper && position != 0 && !after_upper {
            snake.push('_');
        }
        snake.push(if upper {
            c.to_lowercase().next().unwrap()
        } else {
            c
        });
        after_upper = upper;
    }

    snake
}
2834
/// Variant-level settings read from `#[action(...)]` by `parse_action_attrs`.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// XML tag name from `#[action(tag = "...")]`; `None` when no string-literal
    /// `tag` was supplied, in which case the variant is ignored by the generators.
    tag: Option<String>,
}
2840
2841fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2842 let mut result = ActionAttrs::default();
2843
2844 for attr in attrs {
2845 if attr.path().is_ident("action")
2846 && let Ok(meta_list) = attr.meta.require_list()
2847 && let Ok(metas) =
2848 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2849 {
2850 for meta in metas {
2851 if let Meta::NameValue(nv) = meta
2852 && nv.path.is_ident("tag")
2853 && let syn::Expr::Lit(syn::ExprLit {
2854 lit: syn::Lit::Str(lit_str),
2855 ..
2856 }) = nv.value
2857 {
2858 result.tag = Some(lit_str.value());
2859 }
2860 }
2861 }
2862 }
2863
2864 result
2865}
2866
/// Field-level flags read from `#[action(attribute)]` / `#[action(inner_text)]`
/// by `parse_field_action_attrs`.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// The field is filled from the XML attribute whose key equals the field name.
    is_attribute: bool,
    /// The field receives the element's text content.
    is_inner_text: bool,
}
2873
2874fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2875 let mut result = FieldActionAttrs::default();
2876
2877 for attr in attrs {
2878 if attr.path().is_ident("action")
2879 && let Ok(meta_list) = attr.meta.require_list()
2880 {
2881 let tokens_str = meta_list.tokens.to_string();
2882 if tokens_str == "attribute" {
2883 result.is_attribute = true;
2884 } else if tokens_str == "inner_text" {
2885 result.is_inner_text = true;
2886 }
2887 }
2888 }
2889
2890 result
2891}
2892
2893fn generate_multi_tag_actions_doc(
2895 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2896) -> String {
2897 let mut doc_lines = Vec::new();
2898
2899 for variant in variants {
2900 let action_attrs = parse_action_attrs(&variant.attrs);
2901
2902 if let Some(tag) = action_attrs.tag {
2903 let variant_docs = extract_doc_comments(&variant.attrs);
2904
2905 match &variant.fields {
2906 syn::Fields::Unit => {
2907 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2909 }
2910 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2911 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2913 }
2914 syn::Fields::Named(fields) => {
2915 let mut attrs_str = Vec::new();
2917 let mut has_inner_text = false;
2918
2919 for field in &fields.named {
2920 let field_name = field.ident.as_ref().unwrap();
2921 let field_attrs = parse_field_action_attrs(&field.attrs);
2922
2923 if field_attrs.is_attribute {
2924 attrs_str.push(format!("{}=\"...\"", field_name));
2925 } else if field_attrs.is_inner_text {
2926 has_inner_text = true;
2927 }
2928 }
2929
2930 let attrs_part = if !attrs_str.is_empty() {
2931 format!(" {}", attrs_str.join(" "))
2932 } else {
2933 String::new()
2934 };
2935
2936 if has_inner_text {
2937 doc_lines.push(format!(
2938 "- `<{}{}>...</{}>`: {}",
2939 tag, attrs_part, tag, variant_docs
2940 ));
2941 } else if !attrs_str.is_empty() {
2942 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2943 } else {
2944 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2945 }
2946
2947 for field in &fields.named {
2949 let field_name = field.ident.as_ref().unwrap();
2950 let field_attrs = parse_field_action_attrs(&field.attrs);
2951 let field_docs = extract_doc_comments(&field.attrs);
2952
2953 if field_attrs.is_attribute {
2954 doc_lines
2955 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2956 } else if field_attrs.is_inner_text {
2957 doc_lines
2958 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2959 }
2960 }
2961 }
2962 _ => {
2963 }
2965 }
2966 }
2967 }
2968
2969 doc_lines.join("\n")
2970}
2971
2972fn generate_tags_regex(
2974 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2975) -> String {
2976 let mut tag_names = Vec::new();
2977
2978 for variant in variants {
2979 let action_attrs = parse_action_attrs(&variant.attrs);
2980 if let Some(tag) = action_attrs.tag {
2981 tag_names.push(tag);
2982 }
2983 }
2984
2985 if tag_names.is_empty() {
2986 return String::new();
2987 }
2988
2989 let tags_pattern = tag_names.join("|");
2990 format!(
2993 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2994 tags_pattern, tags_pattern, tags_pattern
2995 )
2996}
2997
2998fn generate_multi_tag_output(
3000 input: &DeriveInput,
3001 enum_name: &syn::Ident,
3002 enum_data: &syn::DataEnum,
3003 prompt_template: String,
3004 actions_doc: String,
3005) -> TokenStream {
3006 let found_crate =
3007 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3008 let crate_path = match found_crate {
3009 FoundCrate::Itself => {
3010 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3012 quote!(::#ident)
3013 }
3014 FoundCrate::Name(name) => {
3015 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3016 quote!(::#ident)
3017 }
3018 };
3019
3020 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
3022 let user_variables: Vec<String> = placeholders
3023 .iter()
3024 .filter_map(|(name, _)| {
3025 if name != "actions_doc" {
3026 Some(name.clone())
3027 } else {
3028 None
3029 }
3030 })
3031 .collect();
3032
3033 let enum_name_str = enum_name.to_string();
3035 let snake_case_name = to_snake_case(&enum_name_str);
3036 let function_name = syn::Ident::new(
3037 &format!("build_{}_prompt", snake_case_name),
3038 proc_macro2::Span::call_site(),
3039 );
3040
3041 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3043 .iter()
3044 .map(|var| {
3045 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3046 quote! { #ident: &str }
3047 })
3048 .collect();
3049
3050 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3052 .iter()
3053 .map(|var| {
3054 let var_str = var.clone();
3055 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3056 quote! {
3057 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
3058 }
3059 })
3060 .collect();
3061
3062 let extractor_name = syn::Ident::new(
3064 &format!("{}Extractor", enum_name),
3065 proc_macro2::Span::call_site(),
3066 );
3067
3068 let filtered_attrs: Vec<_> = input
3070 .attrs
3071 .iter()
3072 .filter(|attr| !attr.path().is_ident("intent"))
3073 .collect();
3074
3075 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
3077 .variants
3078 .iter()
3079 .map(|variant| {
3080 let variant_name = &variant.ident;
3081 let variant_attrs: Vec<_> = variant
3082 .attrs
3083 .iter()
3084 .filter(|attr| !attr.path().is_ident("action"))
3085 .collect();
3086 let fields = &variant.fields;
3087
3088 let filtered_fields = match fields {
3090 syn::Fields::Named(named_fields) => {
3091 let filtered: Vec<_> = named_fields
3092 .named
3093 .iter()
3094 .map(|field| {
3095 let field_name = &field.ident;
3096 let field_type = &field.ty;
3097 let field_vis = &field.vis;
3098 let filtered_attrs: Vec<_> = field
3099 .attrs
3100 .iter()
3101 .filter(|attr| !attr.path().is_ident("action"))
3102 .collect();
3103 quote! {
3104 #(#filtered_attrs)*
3105 #field_vis #field_name: #field_type
3106 }
3107 })
3108 .collect();
3109 quote! { { #(#filtered,)* } }
3110 }
3111 syn::Fields::Unnamed(unnamed_fields) => {
3112 let types: Vec<_> = unnamed_fields
3113 .unnamed
3114 .iter()
3115 .map(|field| {
3116 let field_type = &field.ty;
3117 quote! { #field_type }
3118 })
3119 .collect();
3120 quote! { (#(#types),*) }
3121 }
3122 syn::Fields::Unit => quote! {},
3123 };
3124
3125 quote! {
3126 #(#variant_attrs)*
3127 #variant_name #filtered_fields
3128 }
3129 })
3130 .collect();
3131
3132 let vis = &input.vis;
3133 let generics = &input.generics;
3134
3135 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
3137
3138 let tags_regex = generate_tags_regex(&enum_data.variants);
3140
3141 let expanded = quote! {
3142 #(#filtered_attrs)*
3144 #vis enum #enum_name #generics {
3145 #(#filtered_variants),*
3146 }
3147
3148 pub fn #function_name(#(#function_params),*) -> String {
3150 let mut env = minijinja::Environment::new();
3151 env.add_template("prompt", #prompt_template)
3152 .expect("Failed to parse intent prompt template");
3153
3154 let tmpl = env.get_template("prompt").unwrap();
3155
3156 let mut __template_context = std::collections::HashMap::new();
3157
3158 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
3160
3161 #(#context_insertions)*
3163
3164 tmpl.render(&__template_context)
3165 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3166 }
3167
3168 pub struct #extractor_name;
3170
3171 impl #extractor_name {
3172 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
3173 use ::quick_xml::events::Event;
3174 use ::quick_xml::Reader;
3175
3176 let mut actions = Vec::new();
3177 let mut reader = Reader::from_str(text);
3178 reader.config_mut().trim_text(true);
3179
3180 let mut buf = Vec::new();
3181
3182 loop {
3183 match reader.read_event_into(&mut buf) {
3184 Ok(Event::Start(e)) => {
3185 let owned_e = e.into_owned();
3186 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3187 let is_empty = false;
3188
3189 #parsing_arms
3190 }
3191 Ok(Event::Empty(e)) => {
3192 let owned_e = e.into_owned();
3193 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3194 let is_empty = true;
3195
3196 #parsing_arms
3197 }
3198 Ok(Event::Eof) => break,
3199 Err(_) => {
3200 break;
3202 }
3203 _ => {}
3204 }
3205 buf.clear();
3206 }
3207
3208 actions.into_iter().next()
3209 }
3210
3211 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
3212 use ::quick_xml::events::Event;
3213 use ::quick_xml::Reader;
3214
3215 let mut actions = Vec::new();
3216 let mut reader = Reader::from_str(text);
3217 reader.config_mut().trim_text(true);
3218
3219 let mut buf = Vec::new();
3220
3221 loop {
3222 match reader.read_event_into(&mut buf) {
3223 Ok(Event::Start(e)) => {
3224 let owned_e = e.into_owned();
3225 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3226 let is_empty = false;
3227
3228 #parsing_arms
3229 }
3230 Ok(Event::Empty(e)) => {
3231 let owned_e = e.into_owned();
3232 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3233 let is_empty = true;
3234
3235 #parsing_arms
3236 }
3237 Ok(Event::Eof) => break,
3238 Err(_) => {
3239 break;
3241 }
3242 _ => {}
3243 }
3244 buf.clear();
3245 }
3246
3247 Ok(actions)
3248 }
3249
3250 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
3251 where
3252 F: FnMut(#enum_name) -> String,
3253 {
3254 use ::regex::Regex;
3255
3256 let regex_pattern = #tags_regex;
3257 if regex_pattern.is_empty() {
3258 return text.to_string();
3259 }
3260
3261 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
3262 panic!("Failed to compile regex for action tags: {}", e);
3263 });
3264
3265 re.replace_all(text, |caps: &::regex::Captures| {
3266 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
3267
3268 if let Some(action) = self.parse_single_action(matched) {
3270 transformer(action)
3271 } else {
3272 matched.to_string()
3274 }
3275 }).to_string()
3276 }
3277
3278 pub fn strip_actions(&self, text: &str) -> String {
3279 self.transform_actions(text, |_| String::new())
3280 }
3281 }
3282 };
3283
3284 TokenStream::from(expanded)
3285}
3286
3287fn generate_parsing_arms(
3289 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3290 enum_name: &syn::Ident,
3291) -> proc_macro2::TokenStream {
3292 let mut arms = Vec::new();
3293
3294 for variant in variants {
3295 let variant_name = &variant.ident;
3296 let action_attrs = parse_action_attrs(&variant.attrs);
3297
3298 if let Some(tag) = action_attrs.tag {
3299 match &variant.fields {
3300 syn::Fields::Unit => {
3301 arms.push(quote! {
3303 if &tag_name == #tag {
3304 actions.push(#enum_name::#variant_name);
3305 }
3306 });
3307 }
3308 syn::Fields::Unnamed(_fields) => {
3309 arms.push(quote! {
3311 if &tag_name == #tag && !is_empty {
3312 match reader.read_text(owned_e.name()) {
3314 Ok(text) => {
3315 actions.push(#enum_name::#variant_name(text.to_string()));
3316 }
3317 Err(_) => {
3318 actions.push(#enum_name::#variant_name(String::new()));
3320 }
3321 }
3322 }
3323 });
3324 }
3325 syn::Fields::Named(fields) => {
3326 let mut field_names = Vec::new();
3328 let mut has_inner_text_field = None;
3329
3330 for field in &fields.named {
3331 let field_name = field.ident.as_ref().unwrap();
3332 let field_attrs = parse_field_action_attrs(&field.attrs);
3333
3334 if field_attrs.is_attribute {
3335 field_names.push(field_name.clone());
3336 } else if field_attrs.is_inner_text {
3337 has_inner_text_field = Some(field_name.clone());
3338 }
3339 }
3340
3341 if let Some(inner_text_field) = has_inner_text_field {
3342 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3345 quote! {
3346 let mut #field_name = String::new();
3347 for attr in owned_e.attributes() {
3348 if let Ok(attr) = attr {
3349 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3350 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3351 break;
3352 }
3353 }
3354 }
3355 }
3356 }).collect();
3357
3358 arms.push(quote! {
3359 if &tag_name == #tag {
3360 #(#attr_extractions)*
3361
3362 if is_empty {
3364 let #inner_text_field = String::new();
3365 actions.push(#enum_name::#variant_name {
3366 #(#field_names,)*
3367 #inner_text_field,
3368 });
3369 } else {
3370 match reader.read_text(owned_e.name()) {
3372 Ok(text) => {
3373 let #inner_text_field = text.to_string();
3374 actions.push(#enum_name::#variant_name {
3375 #(#field_names,)*
3376 #inner_text_field,
3377 });
3378 }
3379 Err(_) => {
3380 let #inner_text_field = String::new();
3382 actions.push(#enum_name::#variant_name {
3383 #(#field_names,)*
3384 #inner_text_field,
3385 });
3386 }
3387 }
3388 }
3389 }
3390 });
3391 } else {
3392 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3394 quote! {
3395 let mut #field_name = String::new();
3396 for attr in owned_e.attributes() {
3397 if let Ok(attr) = attr {
3398 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3399 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3400 break;
3401 }
3402 }
3403 }
3404 }
3405 }).collect();
3406
3407 arms.push(quote! {
3408 if &tag_name == #tag {
3409 #(#attr_extractions)*
3410 actions.push(#enum_name::#variant_name {
3411 #(#field_names),*
3412 });
3413 }
3414 });
3415 }
3416 }
3417 }
3418 }
3419 }
3420
3421 quote! {
3422 #(#arms)*
3423 }
3424}
3425
3426#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3428pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3429 let input = parse_macro_input!(input as DeriveInput);
3430
3431 let found_crate =
3432 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3433 let crate_path = match found_crate {
3434 FoundCrate::Itself => {
3435 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3437 quote!(::#ident)
3438 }
3439 FoundCrate::Name(name) => {
3440 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3441 quote!(::#ident)
3442 }
3443 };
3444
3445 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
3447
3448 let struct_name = &input.ident;
3449 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3450
3451 let placeholders = parse_template_placeholders_with_mode(&template);
3453
3454 let mut converted_template = template.clone();
3456 let mut context_fields = Vec::new();
3457
3458 let fields = match &input.data {
3460 Data::Struct(data_struct) => match &data_struct.fields {
3461 syn::Fields::Named(fields) => &fields.named,
3462 _ => panic!("ToPromptFor is only supported for structs with named fields"),
3463 },
3464 _ => panic!("ToPromptFor is only supported for structs"),
3465 };
3466
3467 let has_mode_support = input.attrs.iter().any(|attr| {
3469 if attr.path().is_ident("prompt")
3470 && let Ok(metas) =
3471 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3472 {
3473 for meta in metas {
3474 if let Meta::NameValue(nv) = meta
3475 && nv.path.is_ident("mode")
3476 {
3477 return true;
3478 }
3479 }
3480 }
3481 false
3482 });
3483
3484 for (placeholder_name, mode_opt) in &placeholders {
3486 if placeholder_name == "self" {
3487 if let Some(specific_mode) = mode_opt {
3488 let unique_key = format!("self__{}", specific_mode);
3490
3491 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3493 let replacement = format!("{{{{ {} }}}}", unique_key);
3494 converted_template = converted_template.replace(&pattern, &replacement);
3495
3496 context_fields.push(quote! {
3498 context.insert(
3499 #unique_key.to_string(),
3500 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3501 );
3502 });
3503 } else {
3504 if has_mode_support {
3507 context_fields.push(quote! {
3509 context.insert(
3510 "self".to_string(),
3511 minijinja::Value::from(self.to_prompt_with_mode(mode))
3512 );
3513 });
3514 } else {
3515 context_fields.push(quote! {
3517 context.insert(
3518 "self".to_string(),
3519 minijinja::Value::from(self.to_prompt())
3520 );
3521 });
3522 }
3523 }
3524 } else {
3525 let field_exists = fields.iter().any(|f| {
3528 f.ident
3529 .as_ref()
3530 .is_some_and(|ident| ident == placeholder_name)
3531 });
3532
3533 if field_exists {
3534 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3535
3536 context_fields.push(quote! {
3540 context.insert(
3541 #placeholder_name.to_string(),
3542 minijinja::Value::from_serialize(&self.#field_ident)
3543 );
3544 });
3545 }
3546 }
3548 }
3549
3550 let expanded = quote! {
3551 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3552 where
3553 #target_type: serde::Serialize,
3554 {
3555 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3556 let mut env = minijinja::Environment::new();
3558 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3559 panic!("Failed to parse template: {}", e)
3560 });
3561
3562 let tmpl = env.get_template("prompt").unwrap();
3563
3564 let mut context = std::collections::HashMap::new();
3566 context.insert(
3568 "self".to_string(),
3569 minijinja::Value::from_serialize(self)
3570 );
3571 context.insert(
3573 "target".to_string(),
3574 minijinja::Value::from_serialize(target)
3575 );
3576 #(#context_fields)*
3577
3578 tmpl.render(context).unwrap_or_else(|e| {
3580 format!("Failed to render prompt: {}", e)
3581 })
3582 }
3583 }
3584 };
3585
3586 TokenStream::from(expanded)
3587}
3588
/// Options accepted inside `#[agent(...)]`, parsed by the `Parse` impl below.
/// Unset options stay `None`; each consumer applies its own defaults.
struct AgentAttrs {
    /// `expertise = "..."`: description returned by `Agent::expertise`
    /// (consumers default it to "general AI assistant").
    expertise: Option<String>,
    /// `output = "Type"`: the agent's output type, parsed from the string with
    /// `syn::parse_str` (consumers default it to `String`).
    output: Option<syn::Type>,
    /// `backend = "..."`: `"claude"` or `"gemini"`; consumers default to `"claude"`.
    backend: Option<String>,
    /// `model = "..."`: forwarded to the backend agent's `with_model_str`.
    model: Option<String>,
    /// `inner = "..."`: name of the inner-agent generic parameter generated by
    /// the `agent` attribute macro (defaults to `"A"` there).
    inner: Option<String>,
    /// `default_inner = "path::Type"`: default type for that generic parameter.
    default_inner: Option<String>,
    /// `max_retries = N`: integer literal passed to the retry helper.
    max_retries: Option<u32>,
    /// `profile = "..."`: `creative` / `balanced` / `deterministic`, matched
    /// case-insensitively by `generate_default_impl` (unknown values map to Balanced).
    profile: Option<String>,
}
3604
3605impl Parse for AgentAttrs {
3606 fn parse(input: ParseStream) -> syn::Result<Self> {
3607 let mut expertise = None;
3608 let mut output = None;
3609 let mut backend = None;
3610 let mut model = None;
3611 let mut inner = None;
3612 let mut default_inner = None;
3613 let mut max_retries = None;
3614 let mut profile = None;
3615
3616 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3617
3618 for meta in pairs {
3619 match meta {
3620 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3621 if let syn::Expr::Lit(syn::ExprLit {
3622 lit: syn::Lit::Str(lit_str),
3623 ..
3624 }) = &nv.value
3625 {
3626 expertise = Some(lit_str.value());
3627 }
3628 }
3629 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3630 if let syn::Expr::Lit(syn::ExprLit {
3631 lit: syn::Lit::Str(lit_str),
3632 ..
3633 }) = &nv.value
3634 {
3635 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3636 output = Some(ty);
3637 }
3638 }
3639 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3640 if let syn::Expr::Lit(syn::ExprLit {
3641 lit: syn::Lit::Str(lit_str),
3642 ..
3643 }) = &nv.value
3644 {
3645 backend = Some(lit_str.value());
3646 }
3647 }
3648 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3649 if let syn::Expr::Lit(syn::ExprLit {
3650 lit: syn::Lit::Str(lit_str),
3651 ..
3652 }) = &nv.value
3653 {
3654 model = Some(lit_str.value());
3655 }
3656 }
3657 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3658 if let syn::Expr::Lit(syn::ExprLit {
3659 lit: syn::Lit::Str(lit_str),
3660 ..
3661 }) = &nv.value
3662 {
3663 inner = Some(lit_str.value());
3664 }
3665 }
3666 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3667 if let syn::Expr::Lit(syn::ExprLit {
3668 lit: syn::Lit::Str(lit_str),
3669 ..
3670 }) = &nv.value
3671 {
3672 default_inner = Some(lit_str.value());
3673 }
3674 }
3675 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3676 if let syn::Expr::Lit(syn::ExprLit {
3677 lit: syn::Lit::Int(lit_int),
3678 ..
3679 }) = &nv.value
3680 {
3681 max_retries = Some(lit_int.base10_parse()?);
3682 }
3683 }
3684 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3685 if let syn::Expr::Lit(syn::ExprLit {
3686 lit: syn::Lit::Str(lit_str),
3687 ..
3688 }) = &nv.value
3689 {
3690 profile = Some(lit_str.value());
3691 }
3692 }
3693 _ => {}
3694 }
3695 }
3696
3697 Ok(AgentAttrs {
3698 expertise,
3699 output,
3700 backend,
3701 model,
3702 inner,
3703 default_inner,
3704 max_retries,
3705 profile,
3706 })
3707 }
3708}
3709
3710fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3712 for attr in attrs {
3713 if attr.path().is_ident("agent") {
3714 return attr.parse_args::<AgentAttrs>();
3715 }
3716 }
3717
3718 Ok(AgentAttrs {
3719 expertise: None,
3720 output: None,
3721 backend: None,
3722 model: None,
3723 inner: None,
3724 default_inner: None,
3725 max_retries: None,
3726 profile: None,
3727 })
3728}
3729
3730fn generate_backend_constructors(
3732 struct_name: &syn::Ident,
3733 backend: &str,
3734 _model: Option<&str>,
3735 _profile: Option<&str>,
3736 crate_path: &proc_macro2::TokenStream,
3737) -> proc_macro2::TokenStream {
3738 match backend {
3739 "claude" => {
3740 quote! {
3741 impl #struct_name {
3742 pub fn with_claude() -> Self {
3744 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3745 }
3746
3747 pub fn with_claude_model(model: &str) -> Self {
3749 Self::new(
3750 #crate_path::agent::impls::ClaudeCodeAgent::new()
3751 .with_model_str(model)
3752 )
3753 }
3754 }
3755 }
3756 }
3757 "gemini" => {
3758 quote! {
3759 impl #struct_name {
3760 pub fn with_gemini() -> Self {
3762 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3763 }
3764
3765 pub fn with_gemini_model(model: &str) -> Self {
3767 Self::new(
3768 #crate_path::agent::impls::GeminiAgent::new()
3769 .with_model_str(model)
3770 )
3771 }
3772 }
3773 }
3774 }
3775 _ => quote! {},
3776 }
3777}
3778
3779fn generate_default_impl(
3781 struct_name: &syn::Ident,
3782 backend: &str,
3783 model: Option<&str>,
3784 profile: Option<&str>,
3785 crate_path: &proc_macro2::TokenStream,
3786) -> proc_macro2::TokenStream {
3787 let profile_expr = if let Some(profile_str) = profile {
3789 match profile_str.to_lowercase().as_str() {
3790 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3791 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3792 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3793 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3795 } else {
3796 quote! { #crate_path::agent::ExecutionProfile::default() }
3797 };
3798
3799 let agent_init = match backend {
3800 "gemini" => {
3801 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3802
3803 if let Some(model_str) = model {
3804 builder = quote! { #builder.with_model_str(#model_str) };
3805 }
3806
3807 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3808 builder
3809 }
3810 _ => {
3811 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3813
3814 if let Some(model_str) = model {
3815 builder = quote! { #builder.with_model_str(#model_str) };
3816 }
3817
3818 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3819 builder
3820 }
3821 };
3822
3823 quote! {
3824 impl Default for #struct_name {
3825 fn default() -> Self {
3826 Self::new(#agent_init)
3827 }
3828 }
3829 }
3830}
3831
3832#[proc_macro_derive(Agent, attributes(agent))]
3841pub fn derive_agent(input: TokenStream) -> TokenStream {
3842 let input = parse_macro_input!(input as DeriveInput);
3843 let struct_name = &input.ident;
3844
3845 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3847 Ok(attrs) => attrs,
3848 Err(e) => return e.to_compile_error().into(),
3849 };
3850
3851 let expertise = agent_attrs
3852 .expertise
3853 .unwrap_or_else(|| String::from("general AI assistant"));
3854 let output_type = agent_attrs
3855 .output
3856 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3857 let backend = agent_attrs
3858 .backend
3859 .unwrap_or_else(|| String::from("claude"));
3860 let model = agent_attrs.model;
3861 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3866 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3867 let crate_path = match found_crate {
3868 FoundCrate::Itself => {
3869 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3871 quote!(::#ident)
3872 }
3873 FoundCrate::Name(name) => {
3874 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3875 quote!(::#ident)
3876 }
3877 };
3878
3879 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3880
3881 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3883 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3884
3885 let enhanced_expertise = if is_string_output {
3887 quote! { #expertise }
3889 } else {
3890 let type_name = quote!(#output_type).to_string();
3892 quote! {
3893 {
3894 use std::sync::OnceLock;
3895 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3896
3897 EXPERTISE_CACHE.get_or_init(|| {
3898 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3900
3901 if schema.is_empty() {
3902 format!(
3904 concat!(
3905 #expertise,
3906 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3907 "Do not include any text outside the JSON object."
3908 ),
3909 #type_name
3910 )
3911 } else {
3912 format!(
3914 concat!(
3915 #expertise,
3916 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3917 ),
3918 schema
3919 )
3920 }
3921 }).as_str()
3922 }
3923 }
3924 };
3925
3926 let agent_init = match backend.as_str() {
3928 "gemini" => {
3929 if let Some(model_str) = model {
3930 quote! {
3931 use #crate_path::agent::impls::GeminiAgent;
3932 let agent = GeminiAgent::new().with_model_str(#model_str);
3933 }
3934 } else {
3935 quote! {
3936 use #crate_path::agent::impls::GeminiAgent;
3937 let agent = GeminiAgent::new();
3938 }
3939 }
3940 }
3941 "claude" => {
3942 if let Some(model_str) = model {
3943 quote! {
3944 use #crate_path::agent::impls::ClaudeCodeAgent;
3945 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3946 }
3947 } else {
3948 quote! {
3949 use #crate_path::agent::impls::ClaudeCodeAgent;
3950 let agent = ClaudeCodeAgent::new();
3951 }
3952 }
3953 }
3954 _ => {
3955 if let Some(model_str) = model {
3957 quote! {
3958 use #crate_path::agent::impls::ClaudeCodeAgent;
3959 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3960 }
3961 } else {
3962 quote! {
3963 use #crate_path::agent::impls::ClaudeCodeAgent;
3964 let agent = ClaudeCodeAgent::new();
3965 }
3966 }
3967 }
3968 };
3969
3970 let expanded = quote! {
3971 #[async_trait::async_trait]
3972 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3973 type Output = #output_type;
3974
3975 fn expertise(&self) -> &str {
3976 #enhanced_expertise
3977 }
3978
3979 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3980 #agent_init
3982
3983 let agent_ref = &agent;
3985 #crate_path::agent::retry::retry_execution(
3986 #max_retries,
3987 &intent,
3988 move |payload| {
3989 let payload = payload.clone();
3990 async move {
3991 let response = agent_ref.execute(payload).await?;
3993
3994 let json_str = #crate_path::extract_json(&response)
3996 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3997 message: format!("Failed to extract JSON: {}", e),
3998 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3999 })?;
4000
4001 serde_json::from_str::<Self::Output>(&json_str)
4003 .map_err(|e| {
4004 let reason = if e.is_eof() {
4006 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4007 } else if e.is_syntax() {
4008 #crate_path::agent::error::ParseErrorReason::InvalidJson
4009 } else {
4010 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4011 };
4012
4013 #crate_path::agent::AgentError::ParseError {
4014 message: format!("Failed to parse JSON: {}", e),
4015 reason,
4016 }
4017 })
4018 }
4019 }
4020 ).await
4021 }
4022
4023 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4024 #agent_init
4026 agent.is_available().await
4027 }
4028 }
4029 };
4030
4031 TokenStream::from(expanded)
4032}
4033
4034#[proc_macro_attribute]
4049pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
4050 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
4052 Ok(attrs) => attrs,
4053 Err(e) => return e.to_compile_error().into(),
4054 };
4055
4056 let input = parse_macro_input!(item as DeriveInput);
4058 let struct_name = &input.ident;
4059 let vis = &input.vis;
4060
4061 let expertise = agent_attrs
4062 .expertise
4063 .unwrap_or_else(|| String::from("general AI assistant"));
4064 let output_type = agent_attrs
4065 .output
4066 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4067 let backend = agent_attrs
4068 .backend
4069 .unwrap_or_else(|| String::from("claude"));
4070 let model = agent_attrs.model;
4071 let profile = agent_attrs.profile;
4072
4073 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4075 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4076
4077 let found_crate =
4079 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4080 let crate_path = match found_crate {
4081 FoundCrate::Itself => {
4082 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4083 quote!(::#ident)
4084 }
4085 FoundCrate::Name(name) => {
4086 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4087 quote!(::#ident)
4088 }
4089 };
4090
4091 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
4093 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
4094
4095 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
4097 let type_path: syn::Type =
4099 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
4100 quote! { #type_path }
4101 } else {
4102 match backend.as_str() {
4104 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
4105 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
4106 }
4107 };
4108
4109 let struct_def = quote! {
4111 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
4112 inner: #inner_generic_ident,
4113 }
4114 };
4115
4116 let constructors = quote! {
4118 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
4119 pub fn new(inner: #inner_generic_ident) -> Self {
4121 Self { inner }
4122 }
4123 }
4124 };
4125
4126 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
4128 let default_impl = quote! {
4130 impl Default for #struct_name {
4131 fn default() -> Self {
4132 Self {
4133 inner: <#default_agent_type as Default>::default(),
4134 }
4135 }
4136 }
4137 };
4138 (quote! {}, default_impl)
4139 } else {
4140 let backend_constructors = generate_backend_constructors(
4142 struct_name,
4143 &backend,
4144 model.as_deref(),
4145 profile.as_deref(),
4146 &crate_path,
4147 );
4148 let default_impl = generate_default_impl(
4149 struct_name,
4150 &backend,
4151 model.as_deref(),
4152 profile.as_deref(),
4153 &crate_path,
4154 );
4155 (backend_constructors, default_impl)
4156 };
4157
4158 let enhanced_expertise = if is_string_output {
4160 quote! { #expertise }
4162 } else {
4163 let type_name = quote!(#output_type).to_string();
4165 quote! {
4166 {
4167 use std::sync::OnceLock;
4168 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4169
4170 EXPERTISE_CACHE.get_or_init(|| {
4171 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4173
4174 if schema.is_empty() {
4175 format!(
4177 concat!(
4178 #expertise,
4179 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4180 "Do not include any text outside the JSON object."
4181 ),
4182 #type_name
4183 )
4184 } else {
4185 format!(
4187 concat!(
4188 #expertise,
4189 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4190 ),
4191 schema
4192 )
4193 }
4194 }).as_str()
4195 }
4196 }
4197 };
4198
4199 let agent_impl = quote! {
4201 #[async_trait::async_trait]
4202 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4203 where
4204 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
4205 {
4206 type Output = #output_type;
4207
4208 fn expertise(&self) -> &str {
4209 #enhanced_expertise
4210 }
4211
4212 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4213 let enhanced_payload = intent.prepend_text(self.expertise());
4215
4216 let response = self.inner.execute(enhanced_payload).await?;
4218
4219 let json_str = #crate_path::extract_json(&response)
4221 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4222 message: e.to_string(),
4223 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4224 })?;
4225
4226 serde_json::from_str(&json_str).map_err(|e| {
4228 let reason = if e.is_eof() {
4229 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4230 } else if e.is_syntax() {
4231 #crate_path::agent::error::ParseErrorReason::InvalidJson
4232 } else {
4233 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4234 };
4235 #crate_path::agent::AgentError::ParseError {
4236 message: e.to_string(),
4237 reason,
4238 }
4239 })
4240 }
4241
4242 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4243 self.inner.is_available().await
4244 }
4245 }
4246 };
4247
4248 let expanded = quote! {
4249 #struct_def
4250 #constructors
4251 #backend_constructors
4252 #default_impl
4253 #agent_impl
4254 };
4255
4256 TokenStream::from(expanded)
4257}
4258
4259#[proc_macro_derive(TypeMarker)]
4281pub fn derive_type_marker(input: TokenStream) -> TokenStream {
4282 let input = parse_macro_input!(input as DeriveInput);
4283 let struct_name = &input.ident;
4284 let type_name_str = struct_name.to_string();
4285
4286 let found_crate =
4288 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4289 let crate_path = match found_crate {
4290 FoundCrate::Itself => {
4291 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4292 quote!(::#ident)
4293 }
4294 FoundCrate::Name(name) => {
4295 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4296 quote!(::#ident)
4297 }
4298 };
4299
4300 let expanded = quote! {
4301 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4302 const TYPE_NAME: &'static str = #type_name_str;
4303 }
4304 };
4305
4306 TokenStream::from(expanded)
4307}
4308
4309#[proc_macro_attribute]
4345pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
4346 let input = parse_macro_input!(item as syn::DeriveInput);
4347 let struct_name = &input.ident;
4348 let vis = &input.vis;
4349 let type_name_str = struct_name.to_string();
4350
4351 let default_fn_name = syn::Ident::new(
4353 &format!("default_{}_type", to_snake_case(&type_name_str)),
4354 struct_name.span(),
4355 );
4356
4357 let found_crate =
4359 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4360 let crate_path = match found_crate {
4361 FoundCrate::Itself => {
4362 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4363 quote!(::#ident)
4364 }
4365 FoundCrate::Name(name) => {
4366 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4367 quote!(::#ident)
4368 }
4369 };
4370
4371 let fields = match &input.data {
4373 syn::Data::Struct(data_struct) => match &data_struct.fields {
4374 syn::Fields::Named(fields) => &fields.named,
4375 _ => {
4376 return syn::Error::new_spanned(
4377 struct_name,
4378 "type_marker only works with structs with named fields",
4379 )
4380 .to_compile_error()
4381 .into();
4382 }
4383 },
4384 _ => {
4385 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
4386 .to_compile_error()
4387 .into();
4388 }
4389 };
4390
4391 let mut new_fields = vec![];
4393
4394 let default_fn_name_str = default_fn_name.to_string();
4396 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
4397
4398 new_fields.push(quote! {
4403 #[serde(default = #default_fn_name_lit)]
4404 __type: String
4405 });
4406
4407 for field in fields {
4409 new_fields.push(quote! { #field });
4410 }
4411
4412 let attrs = &input.attrs;
4414 let generics = &input.generics;
4415
4416 let expanded = quote! {
4417 fn #default_fn_name() -> String {
4419 #type_name_str.to_string()
4420 }
4421
4422 #(#attrs)*
4424 #vis struct #struct_name #generics {
4425 #(#new_fields),*
4426 }
4427
4428 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4430 const TYPE_NAME: &'static str = #type_name_str;
4431 }
4432 };
4433
4434 TokenStream::from(expanded)
4435}