1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
169
170 if is_vec {
171 let comment = if !field_docs.is_empty() {
174 format!(" // {}", field_docs)
175 } else {
176 String::new()
177 };
178
179 field_schema_parts.push(quote! {
180 {
181 let type_name = stringify!(#inner_type);
182 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
183 }
184 });
185
186 if let Some(inner) = inner_type
188 && !is_primitive_type(inner)
189 {
190 nested_type_collectors.push(quote! {
191 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
192 });
193 }
194 } else {
195 let field_type = &field.ty;
197 let is_primitive = is_primitive_type(field_type);
198
199 if !is_primitive {
200 let comment = if !field_docs.is_empty() {
203 format!(" // {}", field_docs)
204 } else {
205 String::new()
206 };
207
208 field_schema_parts.push(quote! {
209 {
210 let type_name = stringify!(#field_type);
211 format!(" {}: {};{}", #field_name_str, type_name, #comment)
212 }
213 });
214
215 nested_type_collectors.push(quote! {
217 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
218 });
219 } else {
220 let type_str = format_type_for_schema(&field.ty);
223 let comment = if !field_docs.is_empty() {
224 format!(" // {}", field_docs)
225 } else {
226 String::new()
227 };
228
229 field_schema_parts.push(quote! {
230 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
231 });
232 }
233 }
234 }
235
236 let mut header_lines = Vec::new();
251
252 if !struct_docs.is_empty() {
254 header_lines.push("/**".to_string());
255 header_lines.push(format!(" * {}", struct_docs));
256 header_lines.push(" */".to_string());
257 }
258
259 header_lines.push(format!("type {} = {{", struct_name));
261
262 quote! {
263 {
264 let mut all_lines: Vec<String> = Vec::new();
265
266 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
268 let mut seen_types = std::collections::HashSet::<String>::new();
269
270 for schema in nested_schemas {
271 if !schema.is_empty() {
272 if seen_types.insert(schema.clone()) {
274 all_lines.push(schema);
275 all_lines.push(String::new()); }
277 }
278 }
279
280 let mut lines: Vec<String> = Vec::new();
282 #(lines.push(#header_lines.to_string());)*
283 #(lines.push(#field_schema_parts);)*
284 lines.push("}".to_string());
285 all_lines.push(lines.join("\n"));
286
287 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
288 }
289 }
290}
291
292fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
294 if let syn::Type::Path(type_path) = ty
295 && let Some(last_segment) = type_path.path.segments.last()
296 && last_segment.ident == "Vec"
297 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
298 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
299 {
300 return (true, Some(inner_type));
301 }
302 (false, None)
303}
304
305fn is_primitive_type(ty: &syn::Type) -> bool {
307 if let syn::Type::Path(type_path) = ty
308 && let Some(last_segment) = type_path.path.segments.last()
309 {
310 let type_name = last_segment.ident.to_string();
311 matches!(
312 type_name.as_str(),
313 "String"
314 | "str"
315 | "i8"
316 | "i16"
317 | "i32"
318 | "i64"
319 | "i128"
320 | "isize"
321 | "u8"
322 | "u16"
323 | "u32"
324 | "u64"
325 | "u128"
326 | "usize"
327 | "f32"
328 | "f64"
329 | "bool"
330 | "Vec"
331 | "Option"
332 | "HashMap"
333 | "BTreeMap"
334 | "HashSet"
335 | "BTreeSet"
336 )
337 } else {
338 true
340 }
341}
342
343fn format_type_for_schema(ty: &syn::Type) -> String {
345 match ty {
347 syn::Type::Path(type_path) => {
348 let path = &type_path.path;
349 if let Some(last_segment) = path.segments.last() {
350 let type_name = last_segment.ident.to_string();
351
352 if type_name == "Option"
354 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
355 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
356 {
357 return format!("{} | null", format_type_for_schema(inner_type));
358 }
359
360 match type_name.as_str() {
362 "String" | "str" => "string".to_string(),
363 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
364 | "u64" | "u128" | "usize" => "number".to_string(),
365 "f32" | "f64" => "number".to_string(),
366 "bool" => "boolean".to_string(),
367 "Vec" => {
368 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
369 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
370 {
371 return format!("{}[]", format_type_for_schema(inner_type));
372 }
373 "array".to_string()
374 }
375 _ => type_name.to_lowercase(),
376 }
377 } else {
378 "unknown".to_string()
379 }
380 }
381 _ => "unknown".to_string(),
382 }
383}
384
/// Options parsed from `#[prompt(...)]` attributes on an enum variant.
#[derive(Default)]
struct PromptAttributes {
    /// Display value to use instead of the variant name (`rename = "..."`).
    rename: Option<String>,
    /// Description text (`description = "..."` or the `#[prompt("...")]` shorthand).
    description: Option<String>,
    /// Set by `#[prompt(skip)]`; the variant is left out of the schema's variant list.
    skip: bool,
}
392
393fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
396 let mut result = PromptAttributes::default();
397
398 for attr in attrs {
399 if attr.path().is_ident("prompt") {
400 if let Ok(meta_list) = attr.meta.require_list() {
402 if let Ok(metas) =
404 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
405 {
406 for meta in metas {
407 if let Meta::NameValue(nv) = meta {
408 if nv.path.is_ident("rename") {
409 if let syn::Expr::Lit(syn::ExprLit {
410 lit: syn::Lit::Str(lit_str),
411 ..
412 }) = nv.value
413 {
414 result.rename = Some(lit_str.value());
415 }
416 } else if nv.path.is_ident("description")
417 && let syn::Expr::Lit(syn::ExprLit {
418 lit: syn::Lit::Str(lit_str),
419 ..
420 }) = nv.value
421 {
422 result.description = Some(lit_str.value());
423 }
424 } else if let Meta::Path(path) = meta
425 && path.is_ident("skip")
426 {
427 result.skip = true;
428 }
429 }
430 }
431
432 let tokens_str = meta_list.tokens.to_string();
434 if tokens_str == "skip" {
435 result.skip = true;
436 }
437 }
438
439 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
441 result.description = Some(lit_str.value());
442 }
443 }
444 }
445 result
446}
447
448fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
450 for attr in attrs {
451 if attr.path().is_ident("serde")
452 && let Ok(meta_list) = attr.meta.require_list()
453 && let Ok(metas) =
454 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
455 {
456 for meta in metas {
457 if let Meta::NameValue(nv) = meta
458 && nv.path.is_ident("rename")
459 && let syn::Expr::Lit(syn::ExprLit {
460 lit: syn::Lit::Str(lit_str),
461 ..
462 }) = nv.value
463 {
464 return Some(lit_str.value());
465 }
466 }
467 }
468 }
469 None
470}
471
/// Case-conversion rules accepted by `#[serde(rename_all = "...")]`.
///
/// `apply` mirrors serde's variant-renaming algorithm, so the values listed in a
/// generated prompt are exactly the strings serde will accept.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leave names unchanged. Never produced by `from_str`.
    #[allow(dead_code)]
    None,
    LowerCase,
    UpperCase,
    PascalCase,
    CamelCase,
    SnakeCase,
    ScreamingSnakeCase,
    KebabCase,
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses a serde `rename_all` value. Returns `None` for unrecognised spellings.
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies the rule to a (PascalCase) variant name, matching serde's behaviour.
    ///
    /// Like serde, case changes are ASCII-only: non-ASCII letters keep their case, but
    /// an uppercase non-ASCII letter still starts a new word. The screaming and kebab
    /// variants are derived from the snake_case form, so underscores already present in
    /// the name become `-` as well (`Foo_Bar` → `foo--bar`).
    fn apply(&self, name: &str) -> String {
        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_ascii_lowercase(),
            Self::UpperCase => name.to_ascii_uppercase(),
            Self::CamelCase => {
                let mut chars = name.chars();
                match chars.next() {
                    None => String::new(),
                    Some(first) => {
                        let mut camel = String::with_capacity(name.len());
                        camel.push(first.to_ascii_lowercase());
                        camel.push_str(chars.as_str());
                        camel
                    }
                }
            }
            Self::SnakeCase => {
                let mut snake = String::with_capacity(name.len() + 4);
                for (index, ch) in name.char_indices() {
                    // Every uppercase letter after the first character starts a new word.
                    if index > 0 && ch.is_uppercase() {
                        snake.push('_');
                    }
                    snake.push(ch.to_ascii_lowercase());
                }
                snake
            }
            Self::ScreamingSnakeCase => Self::SnakeCase.apply(name).to_ascii_uppercase(),
            Self::KebabCase => Self::SnakeCase.apply(name).replace('_', "-"),
            Self::ScreamingKebabCase => Self::ScreamingSnakeCase.apply(name).replace('_', "-"),
        }
    }
}
565
566fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
568 for attr in attrs {
569 if attr.path().is_ident("serde")
570 && let Ok(meta_list) = attr.meta.require_list()
571 {
572 if let Ok(metas) =
574 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
575 {
576 for meta in metas {
577 if let Meta::NameValue(nv) = meta
578 && nv.path.is_ident("rename_all")
579 && let syn::Expr::Lit(syn::ExprLit {
580 lit: syn::Lit::Str(lit_str),
581 ..
582 }) = nv.value
583 {
584 return RenameRule::from_str(&lit_str.value());
585 }
586 }
587 }
588 }
589 }
590 None
591}
592
/// Options parsed from `#[prompt(...)]` attributes on a struct field.
///
/// Field order is significant for the derived `Debug` output.
#[derive(Default, Debug)]
struct FieldPromptAttrs {
    /// `#[prompt(skip)]`: the field is omitted from generated schema and example output.
    skip: bool,
    /// `#[prompt(rename = "...")]`.
    rename: Option<String>,
    /// `#[prompt(format_with = "...")]`; not consumed by the code visible here.
    format_with: Option<String>,
    /// `#[prompt(image)]`: in template mode the field's own `to_prompt_parts()` are
    /// prepended to the rendered text.
    image: bool,
    /// `#[prompt(example = "...")]`: literal value used as a JSON string in examples.
    example: Option<String>,
}
602
603fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
605 let mut result = FieldPromptAttrs::default();
606
607 for attr in attrs {
608 if attr.path().is_ident("prompt") {
609 if let Ok(meta_list) = attr.meta.require_list() {
611 if let Ok(metas) =
613 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
614 {
615 for meta in metas {
616 match meta {
617 Meta::Path(path) if path.is_ident("skip") => {
618 result.skip = true;
619 }
620 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
621 if let syn::Expr::Lit(syn::ExprLit {
622 lit: syn::Lit::Str(lit_str),
623 ..
624 }) = nv.value
625 {
626 result.rename = Some(lit_str.value());
627 }
628 }
629 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
630 if let syn::Expr::Lit(syn::ExprLit {
631 lit: syn::Lit::Str(lit_str),
632 ..
633 }) = nv.value
634 {
635 result.format_with = Some(lit_str.value());
636 }
637 }
638 Meta::Path(path) if path.is_ident("image") => {
639 result.image = true;
640 }
641 Meta::NameValue(nv) if nv.path.is_ident("example") => {
642 if let syn::Expr::Lit(syn::ExprLit {
643 lit: syn::Lit::Str(lit_str),
644 ..
645 }) = nv.value
646 {
647 result.example = Some(lit_str.value());
648 }
649 }
650 _ => {}
651 }
652 }
653 } else if meta_list.tokens.to_string() == "skip" {
654 result.skip = true;
656 } else if meta_list.tokens.to_string() == "image" {
657 result.image = true;
659 }
660 }
661 }
662 }
663
664 result
665}
666
667#[proc_macro_derive(ToPrompt, attributes(prompt))]
710pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
711 let input = parse_macro_input!(input as DeriveInput);
712
713 let found_crate =
714 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
715 let crate_path = match found_crate {
716 FoundCrate::Itself => {
717 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
719 quote!(::#ident)
720 }
721 FoundCrate::Name(name) => {
722 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
723 quote!(::#ident)
724 }
725 };
726
727 match &input.data {
729 Data::Enum(data_enum) => {
730 let enum_name = &input.ident;
732 let enum_docs = extract_doc_comments(&input.attrs);
733
734 let rename_rule = parse_serde_rename_all(&input.attrs);
736
737 let mut variant_lines = Vec::new();
750 let mut first_variant_name = None;
751
752 for variant in &data_enum.variants {
753 let variant_name = &variant.ident;
754 let variant_name_str = variant_name.to_string();
755
756 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
758
759 if prompt_attrs.skip {
761 continue;
762 }
763
764 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
770 prompt_rename.clone()
771 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
772 serde_rename
773 } else if let Some(rule) = rename_rule {
774 rule.apply(&variant_name_str)
775 } else {
776 variant_name_str.clone()
777 };
778
779 let variant_line = if let Some(desc) = &prompt_attrs.description {
781 format!(" | \"{}\" // {}", variant_value, desc)
782 } else {
783 let docs = extract_doc_comments(&variant.attrs);
784 if !docs.is_empty() {
785 format!(" | \"{}\" // {}", variant_value, docs)
786 } else {
787 format!(" | \"{}\"", variant_value)
788 }
789 };
790
791 variant_lines.push(variant_line);
792
793 if first_variant_name.is_none() {
794 first_variant_name = Some(variant_value);
795 }
796 }
797
798 let mut lines = Vec::new();
800
801 if !enum_docs.is_empty() {
803 lines.push("/**".to_string());
804 lines.push(format!(" * {}", enum_docs));
805 lines.push(" */".to_string());
806 }
807
808 lines.push(format!("type {} =", enum_name));
810
811 for line in &variant_lines {
813 lines.push(line.clone());
814 }
815
816 if let Some(last) = lines.last_mut()
818 && !last.ends_with(';')
819 {
820 last.push(';');
821 }
822
823 if let Some(first_name) = first_variant_name {
825 lines.push("".to_string()); lines.push(format!("Example value: \"{}\"", first_name));
827 }
828
829 let prompt_string = lines.join("\n");
830 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
831
832 let mut match_arms = Vec::new();
834 for variant in &data_enum.variants {
835 let variant_name = &variant.ident;
836 let variant_name_str = variant_name.to_string();
837
838 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
840
841 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
847 prompt_rename.clone()
848 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
849 serde_rename
850 } else if let Some(rule) = rename_rule {
851 rule.apply(&variant_name_str)
852 } else {
853 variant_name_str.clone()
854 };
855
856 if prompt_attrs.skip {
858 match_arms.push(quote! {
860 Self::#variant_name => stringify!(#variant_name).to_string()
861 });
862 } else if let Some(desc) = &prompt_attrs.description {
863 match_arms.push(quote! {
865 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
866 });
867 } else {
868 let variant_docs = extract_doc_comments(&variant.attrs);
870 if !variant_docs.is_empty() {
871 match_arms.push(quote! {
872 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
873 });
874 } else {
875 match_arms.push(quote! {
876 Self::#variant_name => #variant_value.to_string()
877 });
878 }
879 }
880 }
881
882 let to_prompt_impl = if match_arms.is_empty() {
883 quote! {
885 fn to_prompt(&self) -> String {
886 match *self {}
887 }
888 }
889 } else {
890 quote! {
891 fn to_prompt(&self) -> String {
892 match self {
893 #(#match_arms),*
894 }
895 }
896 }
897 };
898
899 let expanded = quote! {
900 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
901 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
902 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
903 }
904
905 #to_prompt_impl
906
907 fn prompt_schema() -> String {
908 #prompt_string.to_string()
909 }
910 }
911 };
912
913 TokenStream::from(expanded)
914 }
915 Data::Struct(data_struct) => {
916 let mut template_attr = None;
918 let mut template_file_attr = None;
919 let mut mode_attr = None;
920 let mut validate_attr = false;
921 let mut type_marker_attr = false;
922
923 for attr in &input.attrs {
924 if attr.path().is_ident("prompt") {
925 if let Ok(metas) =
927 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
928 {
929 for meta in metas {
930 match meta {
931 Meta::NameValue(nv) if nv.path.is_ident("template") => {
932 if let syn::Expr::Lit(expr_lit) = nv.value
933 && let syn::Lit::Str(lit_str) = expr_lit.lit
934 {
935 template_attr = Some(lit_str.value());
936 }
937 }
938 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
939 if let syn::Expr::Lit(expr_lit) = nv.value
940 && let syn::Lit::Str(lit_str) = expr_lit.lit
941 {
942 template_file_attr = Some(lit_str.value());
943 }
944 }
945 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
946 if let syn::Expr::Lit(expr_lit) = nv.value
947 && let syn::Lit::Str(lit_str) = expr_lit.lit
948 {
949 mode_attr = Some(lit_str.value());
950 }
951 }
952 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
953 if let syn::Expr::Lit(expr_lit) = nv.value
954 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
955 {
956 validate_attr = lit_bool.value();
957 }
958 }
959 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
960 if let syn::Expr::Lit(expr_lit) = nv.value
961 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
962 {
963 type_marker_attr = lit_bool.value();
964 }
965 }
966 Meta::Path(path) if path.is_ident("type_marker") => {
967 type_marker_attr = true;
969 }
970 _ => {}
971 }
972 }
973 }
974 }
975 }
976
977 if template_attr.is_some() && template_file_attr.is_some() {
979 return syn::Error::new(
980 input.ident.span(),
981 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
982 ).to_compile_error().into();
983 }
984
985 let template_str = if let Some(file_path) = template_file_attr {
987 let mut full_path = None;
991
992 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
994 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
996
997 if !is_trybuild {
998 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1000 if candidate.exists() {
1001 full_path = Some(candidate);
1002 }
1003 } else {
1004 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1010 let workspace_root = &manifest_dir[..target_pos];
1011 let original_macros_dir = std::path::Path::new(workspace_root)
1013 .join("crates")
1014 .join("llm-toolkit-macros");
1015
1016 let candidate = original_macros_dir.join(&file_path);
1017 if candidate.exists() {
1018 full_path = Some(candidate);
1019 }
1020 }
1021 }
1022 }
1023
1024 if full_path.is_none() {
1026 let candidate = std::path::Path::new(&file_path).to_path_buf();
1027 if candidate.exists() {
1028 full_path = Some(candidate);
1029 }
1030 }
1031
1032 if full_path.is_none()
1035 && let Ok(current_dir) = std::env::current_dir()
1036 {
1037 let mut search_dir = current_dir.as_path();
1038 for _ in 0..10 {
1040 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1042 if macros_dir.exists() {
1043 let candidate = macros_dir.join(&file_path);
1044 if candidate.exists() {
1045 full_path = Some(candidate);
1046 break;
1047 }
1048 }
1049 let candidate = search_dir.join(&file_path);
1051 if candidate.exists() {
1052 full_path = Some(candidate);
1053 break;
1054 }
1055 if let Some(parent) = search_dir.parent() {
1056 search_dir = parent;
1057 } else {
1058 break;
1059 }
1060 }
1061 }
1062
1063 if full_path.is_none() {
1065 let mut error_msg = format!(
1067 "Template file '{}' not found at compile time.\n\nSearched in:",
1068 file_path
1069 );
1070
1071 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1072 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1073 error_msg.push_str(&format!("\n - {}", candidate.display()));
1074 }
1075
1076 if let Ok(current_dir) = std::env::current_dir() {
1077 let candidate = current_dir.join(&file_path);
1078 error_msg.push_str(&format!("\n - {}", candidate.display()));
1079 }
1080
1081 error_msg.push_str("\n\nPlease ensure:");
1082 error_msg.push_str("\n 1. The template file exists");
1083 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1084 error_msg.push_str("\n 3. There are no typos in the path");
1085
1086 return syn::Error::new(input.ident.span(), error_msg)
1087 .to_compile_error()
1088 .into();
1089 }
1090
1091 let final_path = full_path.unwrap();
1092
1093 match std::fs::read_to_string(&final_path) {
1095 Ok(content) => Some(content),
1096 Err(e) => {
1097 return syn::Error::new(
1098 input.ident.span(),
1099 format!(
1100 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1101 file_path,
1102 e,
1103 final_path.display()
1104 ),
1105 )
1106 .to_compile_error()
1107 .into();
1108 }
1109 }
1110 } else {
1111 template_attr
1112 };
1113
1114 if validate_attr && let Some(template) = &template_str {
1116 let mut env = minijinja::Environment::new();
1118 if let Err(e) = env.add_template("validation", template) {
1119 let warning_msg =
1121 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1122 let warning_ident = syn::Ident::new(
1123 "TEMPLATE_VALIDATION_WARNING",
1124 proc_macro2::Span::call_site(),
1125 );
1126 let _warning_tokens = quote! {
1127 #[deprecated(note = #warning_msg)]
1128 const #warning_ident: () = ();
1129 let _ = #warning_ident;
1130 };
1131 eprintln!("cargo:warning={}", warning_msg);
1133 }
1134
1135 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1137 &fields.named
1138 } else {
1139 panic!("Template validation is only supported for structs with named fields.");
1140 };
1141
1142 let field_names: std::collections::HashSet<String> = fields
1143 .iter()
1144 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1145 .collect();
1146
1147 let placeholders = parse_template_placeholders_with_mode(template);
1149
1150 for (placeholder_name, _mode) in &placeholders {
1151 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1152 let warning_msg = format!(
1153 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1154 placeholder_name
1155 );
1156 eprintln!("cargo:warning={}", warning_msg);
1157 }
1158 }
1159 }
1160
1161 let name = input.ident;
1162 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1163
1164 let struct_docs = extract_doc_comments(&input.attrs);
1166
1167 let is_mode_based =
1169 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1170
1171 let expanded = if is_mode_based || mode_attr.is_some() {
1172 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1174 &fields.named
1175 } else {
1176 panic!(
1177 "Mode-based prompt generation is only supported for structs with named fields."
1178 );
1179 };
1180
1181 let struct_name_str = name.to_string();
1182
1183 let has_default = input.attrs.iter().any(|attr| {
1185 if attr.path().is_ident("derive")
1186 && let Ok(meta_list) = attr.meta.require_list()
1187 {
1188 let tokens_str = meta_list.tokens.to_string();
1189 tokens_str.contains("Default")
1190 } else {
1191 false
1192 }
1193 });
1194
1195 let schema_parts = generate_schema_only_parts(
1206 &struct_name_str,
1207 &struct_docs,
1208 fields,
1209 &crate_path,
1210 type_marker_attr,
1211 );
1212
1213 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1215
1216 quote! {
1217 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1218 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1219 match mode {
1220 "schema_only" => #schema_parts,
1221 "example_only" => #example_parts,
1222 "full" | _ => {
1223 let mut parts = Vec::new();
1225
1226 let schema_parts = #schema_parts;
1228 parts.extend(schema_parts);
1229
1230 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1232 parts.push(#crate_path::prompt::PromptPart::Text(
1233 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1234 ));
1235
1236 let example_parts = #example_parts;
1238 parts.extend(example_parts);
1239
1240 parts
1241 }
1242 }
1243 }
1244
1245 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1246 self.to_prompt_parts_with_mode("full")
1247 }
1248
1249 fn to_prompt(&self) -> String {
1250 self.to_prompt_parts()
1251 .into_iter()
1252 .filter_map(|part| match part {
1253 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1254 _ => None,
1255 })
1256 .collect::<Vec<_>>()
1257 .join("\n")
1258 }
1259
1260 fn prompt_schema() -> String {
1261 use std::sync::OnceLock;
1262 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1263
1264 SCHEMA_CACHE.get_or_init(|| {
1265 let schema_parts = #schema_parts;
1266 schema_parts
1267 .into_iter()
1268 .filter_map(|part| match part {
1269 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1270 _ => None,
1271 })
1272 .collect::<Vec<_>>()
1273 .join("\n")
1274 }).clone()
1275 }
1276 }
1277 }
1278 } else if let Some(template) = template_str {
1279 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1282 &fields.named
1283 } else {
1284 panic!(
1285 "Template prompt generation is only supported for structs with named fields."
1286 );
1287 };
1288
1289 let placeholders = parse_template_placeholders_with_mode(&template);
1291 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1293 mode.is_some()
1294 && fields
1295 .iter()
1296 .any(|f| f.ident.as_ref().unwrap() == field_name)
1297 });
1298
1299 let mut image_field_parts = Vec::new();
1300 for f in fields.iter() {
1301 let field_name = f.ident.as_ref().unwrap();
1302 let attrs = parse_field_prompt_attrs(&f.attrs);
1303
1304 if attrs.image {
1305 image_field_parts.push(quote! {
1307 parts.extend(self.#field_name.to_prompt_parts());
1308 });
1309 }
1310 }
1311
1312 if has_mode_syntax {
1314 let mut context_fields = Vec::new();
1316 let mut modified_template = template.clone();
1317
1318 for (field_name, mode_opt) in &placeholders {
1320 if let Some(mode) = mode_opt {
1321 let unique_key = format!("{}__{}", field_name, mode);
1323
1324 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1326 let replacement = format!("{{{{ {} }}}}", unique_key);
1327 modified_template = modified_template.replace(&pattern, &replacement);
1328
1329 let field_ident =
1331 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1332
1333 context_fields.push(quote! {
1335 context.insert(
1336 #unique_key.to_string(),
1337 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1338 );
1339 });
1340 }
1341 }
1342
1343 for field in fields.iter() {
1345 let field_name = field.ident.as_ref().unwrap();
1346 let field_name_str = field_name.to_string();
1347
1348 let has_mode_entry = placeholders
1350 .iter()
1351 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1352
1353 if !has_mode_entry {
1354 let is_primitive = match &field.ty {
1357 syn::Type::Path(type_path) => {
1358 if let Some(segment) = type_path.path.segments.last() {
1359 let type_name = segment.ident.to_string();
1360 matches!(
1361 type_name.as_str(),
1362 "String"
1363 | "str"
1364 | "i8"
1365 | "i16"
1366 | "i32"
1367 | "i64"
1368 | "i128"
1369 | "isize"
1370 | "u8"
1371 | "u16"
1372 | "u32"
1373 | "u64"
1374 | "u128"
1375 | "usize"
1376 | "f32"
1377 | "f64"
1378 | "bool"
1379 | "char"
1380 )
1381 } else {
1382 false
1383 }
1384 }
1385 _ => false,
1386 };
1387
1388 if is_primitive {
1389 context_fields.push(quote! {
1390 context.insert(
1391 #field_name_str.to_string(),
1392 minijinja::Value::from_serialize(&self.#field_name)
1393 );
1394 });
1395 } else {
1396 context_fields.push(quote! {
1398 context.insert(
1399 #field_name_str.to_string(),
1400 minijinja::Value::from(self.#field_name.to_prompt())
1401 );
1402 });
1403 }
1404 }
1405 }
1406
1407 quote! {
1408 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1409 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1410 let mut parts = Vec::new();
1411
1412 #(#image_field_parts)*
1414
1415 let text = {
1417 let mut env = minijinja::Environment::new();
1418 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1419 panic!("Failed to parse template: {}", e)
1420 });
1421
1422 let tmpl = env.get_template("prompt").unwrap();
1423
1424 let mut context = std::collections::HashMap::new();
1425 #(#context_fields)*
1426
1427 tmpl.render(context).unwrap_or_else(|e| {
1428 format!("Failed to render prompt: {}", e)
1429 })
1430 };
1431
1432 if !text.is_empty() {
1433 parts.push(#crate_path::prompt::PromptPart::Text(text));
1434 }
1435
1436 parts
1437 }
1438
1439 fn to_prompt(&self) -> String {
1440 let mut env = minijinja::Environment::new();
1442 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1443 panic!("Failed to parse template: {}", e)
1444 });
1445
1446 let tmpl = env.get_template("prompt").unwrap();
1447
1448 let mut context = std::collections::HashMap::new();
1449 #(#context_fields)*
1450
1451 tmpl.render(context).unwrap_or_else(|e| {
1452 format!("Failed to render prompt: {}", e)
1453 })
1454 }
1455
1456 fn prompt_schema() -> String {
1457 String::new() }
1459 }
1460 }
1461 } else {
1462 quote! {
1464 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1465 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1466 let mut parts = Vec::new();
1467
1468 #(#image_field_parts)*
1470
1471 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1473 format!("Failed to render prompt: {}", e)
1474 });
1475 if !text.is_empty() {
1476 parts.push(#crate_path::prompt::PromptPart::Text(text));
1477 }
1478
1479 parts
1480 }
1481
1482 fn to_prompt(&self) -> String {
1483 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1484 format!("Failed to render prompt: {}", e)
1485 })
1486 }
1487
1488 fn prompt_schema() -> String {
1489 String::new() }
1491 }
1492 }
1493 }
1494 } else {
1495 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1498 &fields.named
1499 } else {
1500 panic!(
1501 "Default prompt generation is only supported for structs with named fields."
1502 );
1503 };
1504
1505 let mut text_field_parts = Vec::new();
1507 let mut image_field_parts = Vec::new();
1508
1509 for f in fields.iter() {
1510 let field_name = f.ident.as_ref().unwrap();
1511 let attrs = parse_field_prompt_attrs(&f.attrs);
1512
1513 if attrs.skip {
1515 continue;
1516 }
1517
1518 if attrs.image {
1519 image_field_parts.push(quote! {
1521 parts.extend(self.#field_name.to_prompt_parts());
1522 });
1523 } else {
1524 let key = if let Some(rename) = attrs.rename {
1530 rename
1531 } else {
1532 let doc_comment = extract_doc_comments(&f.attrs);
1533 if !doc_comment.is_empty() {
1534 doc_comment
1535 } else {
1536 field_name.to_string()
1537 }
1538 };
1539
1540 let value_expr = if let Some(format_with) = attrs.format_with {
1542 let func_path: syn::Path =
1544 syn::parse_str(&format_with).unwrap_or_else(|_| {
1545 panic!("Invalid function path: {}", format_with)
1546 });
1547 quote! { #func_path(&self.#field_name) }
1548 } else {
1549 quote! { self.#field_name.to_prompt() }
1550 };
1551
1552 text_field_parts.push(quote! {
1553 text_parts.push(format!("{}: {}", #key, #value_expr));
1554 });
1555 }
1556 }
1557
1558 quote! {
1560 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1561 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1562 let mut parts = Vec::new();
1563
1564 #(#image_field_parts)*
1566
1567 let mut text_parts = Vec::new();
1569 #(#text_field_parts)*
1570
1571 if !text_parts.is_empty() {
1572 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1573 }
1574
1575 parts
1576 }
1577
1578 fn to_prompt(&self) -> String {
1579 let mut text_parts = Vec::new();
1580 #(#text_field_parts)*
1581 text_parts.join("\n")
1582 }
1583
1584 fn prompt_schema() -> String {
1585 String::new() }
1587 }
1588 }
1589 };
1590
1591 TokenStream::from(expanded)
1592 }
1593 Data::Union(_) => {
1594 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1595 }
1596 }
1597}
1598
1599#[derive(Debug, Clone)]
1601struct TargetInfo {
1602 name: String,
1603 template: Option<String>,
1604 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1605}
1606
/// Field-level settings parsed from one `#[prompt_for(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Omit the field from the target's output.
    skip: bool,
    /// Label to print instead of the field name.
    rename: Option<String>,
    /// Path of a function used to format the field value (called as `f(&self.field)`).
    format_with: Option<String>,
    /// Emit the field through `to_prompt_parts()` (image part) rather than as text.
    image: bool,
    /// Set when the attribute named a target explicitly (`name = "..."`).
    include_only: bool,
}
1616
1617fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1619 let mut configs = Vec::new();
1620
1621 for attr in attrs {
1622 if attr.path().is_ident("prompt_for")
1623 && let Ok(meta_list) = attr.meta.require_list()
1624 {
1625 if meta_list.tokens.to_string() == "skip" {
1627 let config = FieldTargetConfig {
1629 skip: true,
1630 ..Default::default()
1631 };
1632 configs.push(("*".to_string(), config));
1633 } else if let Ok(metas) =
1634 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1635 {
1636 let mut target_name = None;
1637 let mut config = FieldTargetConfig::default();
1638
1639 for meta in metas {
1640 match meta {
1641 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1642 if let syn::Expr::Lit(syn::ExprLit {
1643 lit: syn::Lit::Str(lit_str),
1644 ..
1645 }) = nv.value
1646 {
1647 target_name = Some(lit_str.value());
1648 }
1649 }
1650 Meta::Path(path) if path.is_ident("skip") => {
1651 config.skip = true;
1652 }
1653 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1654 if let syn::Expr::Lit(syn::ExprLit {
1655 lit: syn::Lit::Str(lit_str),
1656 ..
1657 }) = nv.value
1658 {
1659 config.rename = Some(lit_str.value());
1660 }
1661 }
1662 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1663 if let syn::Expr::Lit(syn::ExprLit {
1664 lit: syn::Lit::Str(lit_str),
1665 ..
1666 }) = nv.value
1667 {
1668 config.format_with = Some(lit_str.value());
1669 }
1670 }
1671 Meta::Path(path) if path.is_ident("image") => {
1672 config.image = true;
1673 }
1674 _ => {}
1675 }
1676 }
1677
1678 if let Some(name) = target_name {
1679 config.include_only = true;
1680 configs.push((name, config));
1681 }
1682 }
1683 }
1684 }
1685
1686 configs
1687}
1688
1689fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1691 let mut targets = Vec::new();
1692
1693 for attr in attrs {
1694 if attr.path().is_ident("prompt_for")
1695 && let Ok(meta_list) = attr.meta.require_list()
1696 && let Ok(metas) =
1697 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1698 {
1699 let mut target_name = None;
1700 let mut template = None;
1701
1702 for meta in metas {
1703 match meta {
1704 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1705 if let syn::Expr::Lit(syn::ExprLit {
1706 lit: syn::Lit::Str(lit_str),
1707 ..
1708 }) = nv.value
1709 {
1710 target_name = Some(lit_str.value());
1711 }
1712 }
1713 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1714 if let syn::Expr::Lit(syn::ExprLit {
1715 lit: syn::Lit::Str(lit_str),
1716 ..
1717 }) = nv.value
1718 {
1719 template = Some(lit_str.value());
1720 }
1721 }
1722 _ => {}
1723 }
1724 }
1725
1726 if let Some(name) = target_name {
1727 targets.push(TargetInfo {
1728 name,
1729 template,
1730 field_configs: std::collections::HashMap::new(),
1731 });
1732 }
1733 }
1734 }
1735
1736 targets
1737}
1738
1739#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1740pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1741 let input = parse_macro_input!(input as DeriveInput);
1742
1743 let found_crate =
1744 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1745 let crate_path = match found_crate {
1746 FoundCrate::Itself => {
1747 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1749 quote!(::#ident)
1750 }
1751 FoundCrate::Name(name) => {
1752 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1753 quote!(::#ident)
1754 }
1755 };
1756
1757 let data_struct = match &input.data {
1759 Data::Struct(data) => data,
1760 _ => {
1761 return syn::Error::new(
1762 input.ident.span(),
1763 "`#[derive(ToPromptSet)]` is only supported for structs",
1764 )
1765 .to_compile_error()
1766 .into();
1767 }
1768 };
1769
1770 let fields = match &data_struct.fields {
1771 syn::Fields::Named(fields) => &fields.named,
1772 _ => {
1773 return syn::Error::new(
1774 input.ident.span(),
1775 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1776 )
1777 .to_compile_error()
1778 .into();
1779 }
1780 };
1781
1782 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1784
1785 for field in fields.iter() {
1787 let field_name = field.ident.as_ref().unwrap().to_string();
1788 let field_configs = parse_prompt_for_attrs(&field.attrs);
1789
1790 for (target_name, config) in field_configs {
1791 if target_name == "*" {
1792 for target in &mut targets {
1794 target
1795 .field_configs
1796 .entry(field_name.clone())
1797 .or_insert_with(FieldTargetConfig::default)
1798 .skip = config.skip;
1799 }
1800 } else {
1801 let target_exists = targets.iter().any(|t| t.name == target_name);
1803 if !target_exists {
1804 targets.push(TargetInfo {
1806 name: target_name.clone(),
1807 template: None,
1808 field_configs: std::collections::HashMap::new(),
1809 });
1810 }
1811
1812 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1813
1814 target.field_configs.insert(field_name.clone(), config);
1815 }
1816 }
1817 }
1818
1819 let mut match_arms = Vec::new();
1821
1822 for target in &targets {
1823 let target_name = &target.name;
1824
1825 if let Some(template_str) = &target.template {
1826 let mut image_parts = Vec::new();
1828
1829 for field in fields.iter() {
1830 let field_name = field.ident.as_ref().unwrap();
1831 let field_name_str = field_name.to_string();
1832
1833 if let Some(config) = target.field_configs.get(&field_name_str)
1834 && config.image
1835 {
1836 image_parts.push(quote! {
1837 parts.extend(self.#field_name.to_prompt_parts());
1838 });
1839 }
1840 }
1841
1842 match_arms.push(quote! {
1843 #target_name => {
1844 let mut parts = Vec::new();
1845
1846 #(#image_parts)*
1847
1848 let text = #crate_path::prompt::render_prompt(#template_str, self)
1849 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1850 target: #target_name.to_string(),
1851 source: e,
1852 })?;
1853
1854 if !text.is_empty() {
1855 parts.push(#crate_path::prompt::PromptPart::Text(text));
1856 }
1857
1858 Ok(parts)
1859 }
1860 });
1861 } else {
1862 let mut text_field_parts = Vec::new();
1864 let mut image_field_parts = Vec::new();
1865
1866 for field in fields.iter() {
1867 let field_name = field.ident.as_ref().unwrap();
1868 let field_name_str = field_name.to_string();
1869
1870 let config = target.field_configs.get(&field_name_str);
1872
1873 if let Some(cfg) = config
1875 && cfg.skip
1876 {
1877 continue;
1878 }
1879
1880 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1884 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1885 .iter()
1886 .any(|(name, _)| name != "*");
1887
1888 if has_any_target_specific_config && !is_explicitly_for_this_target {
1889 continue;
1890 }
1891
1892 if let Some(cfg) = config {
1893 if cfg.image {
1894 image_field_parts.push(quote! {
1895 parts.extend(self.#field_name.to_prompt_parts());
1896 });
1897 } else {
1898 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1899
1900 let value_expr = if let Some(format_with) = &cfg.format_with {
1901 match syn::parse_str::<syn::Path>(format_with) {
1903 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1904 Err(_) => {
1905 let error_msg = format!(
1907 "Invalid function path in format_with: '{}'",
1908 format_with
1909 );
1910 quote! {
1911 compile_error!(#error_msg);
1912 String::new()
1913 }
1914 }
1915 }
1916 } else {
1917 quote! { self.#field_name.to_prompt() }
1918 };
1919
1920 text_field_parts.push(quote! {
1921 text_parts.push(format!("{}: {}", #key, #value_expr));
1922 });
1923 }
1924 } else {
1925 text_field_parts.push(quote! {
1927 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1928 });
1929 }
1930 }
1931
1932 match_arms.push(quote! {
1933 #target_name => {
1934 let mut parts = Vec::new();
1935
1936 #(#image_field_parts)*
1937
1938 let mut text_parts = Vec::new();
1939 #(#text_field_parts)*
1940
1941 if !text_parts.is_empty() {
1942 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1943 }
1944
1945 Ok(parts)
1946 }
1947 });
1948 }
1949 }
1950
1951 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1953
1954 match_arms.push(quote! {
1956 _ => {
1957 let available = vec![#(#target_names.to_string()),*];
1958 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1959 target: target.to_string(),
1960 available,
1961 })
1962 }
1963 });
1964
1965 let struct_name = &input.ident;
1966 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1967
1968 let expanded = quote! {
1969 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1970 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1971 match target {
1972 #(#match_arms)*
1973 }
1974 }
1975 }
1976 };
1977
1978 TokenStream::from(expanded)
1979}
1980
1981struct TypeList {
1983 types: Punctuated<syn::Type, Token![,]>,
1984}
1985
1986impl Parse for TypeList {
1987 fn parse(input: ParseStream) -> syn::Result<Self> {
1988 Ok(TypeList {
1989 types: Punctuated::parse_terminated(input)?,
1990 })
1991 }
1992}
1993
1994#[proc_macro]
2018pub fn examples_section(input: TokenStream) -> TokenStream {
2019 let input = parse_macro_input!(input as TypeList);
2020
2021 let found_crate =
2022 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2023 let _crate_path = match found_crate {
2024 FoundCrate::Itself => quote!(crate),
2025 FoundCrate::Name(name) => {
2026 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2027 quote!(::#ident)
2028 }
2029 };
2030
2031 let mut type_sections = Vec::new();
2033
2034 for ty in input.types.iter() {
2035 let type_name_str = quote!(#ty).to_string();
2037
2038 type_sections.push(quote! {
2040 {
2041 let type_name = #type_name_str;
2042 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
2043 format!("---\n#### `{}`\n{}", type_name, json_example)
2044 }
2045 });
2046 }
2047
2048 let expanded = quote! {
2050 {
2051 let mut sections = Vec::new();
2052 sections.push("---".to_string());
2053 sections.push("### Examples".to_string());
2054 sections.push("".to_string());
2055 sections.push("Here are examples of the data structures you should use.".to_string());
2056 sections.push("".to_string());
2057
2058 #(sections.push(#type_sections);)*
2059
2060 sections.push("---".to_string());
2061
2062 sections.join("\n")
2063 }
2064 };
2065
2066 TokenStream::from(expanded)
2067}
2068
2069fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
2071 for attr in attrs {
2072 if attr.path().is_ident("prompt_for")
2073 && let Ok(meta_list) = attr.meta.require_list()
2074 && let Ok(metas) =
2075 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2076 {
2077 let mut target_type = None;
2078 let mut template = None;
2079
2080 for meta in metas {
2081 match meta {
2082 Meta::NameValue(nv) if nv.path.is_ident("target") => {
2083 if let syn::Expr::Lit(syn::ExprLit {
2084 lit: syn::Lit::Str(lit_str),
2085 ..
2086 }) = nv.value
2087 {
2088 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2090 }
2091 }
2092 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2093 if let syn::Expr::Lit(syn::ExprLit {
2094 lit: syn::Lit::Str(lit_str),
2095 ..
2096 }) = nv.value
2097 {
2098 template = Some(lit_str.value());
2099 }
2100 }
2101 _ => {}
2102 }
2103 }
2104
2105 if let (Some(target), Some(tmpl)) = (target_type, template) {
2106 return (target, tmpl);
2107 }
2108 }
2109 }
2110
2111 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2112}
2113
2114#[proc_macro_attribute]
2148pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2149 let input = parse_macro_input!(item as DeriveInput);
2150
2151 let found_crate =
2152 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2153 let crate_path = match found_crate {
2154 FoundCrate::Itself => {
2155 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2157 quote!(::#ident)
2158 }
2159 FoundCrate::Name(name) => {
2160 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2161 quote!(::#ident)
2162 }
2163 };
2164
2165 let enum_data = match &input.data {
2167 Data::Enum(data) => data,
2168 _ => {
2169 return syn::Error::new(
2170 input.ident.span(),
2171 "`#[define_intent]` can only be applied to enums",
2172 )
2173 .to_compile_error()
2174 .into();
2175 }
2176 };
2177
2178 let mut prompt_template = None;
2180 let mut extractor_tag = None;
2181 let mut mode = None;
2182
2183 for attr in &input.attrs {
2184 if attr.path().is_ident("intent")
2185 && let Ok(metas) =
2186 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2187 {
2188 for meta in metas {
2189 match meta {
2190 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2191 if let syn::Expr::Lit(syn::ExprLit {
2192 lit: syn::Lit::Str(lit_str),
2193 ..
2194 }) = nv.value
2195 {
2196 prompt_template = Some(lit_str.value());
2197 }
2198 }
2199 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2200 if let syn::Expr::Lit(syn::ExprLit {
2201 lit: syn::Lit::Str(lit_str),
2202 ..
2203 }) = nv.value
2204 {
2205 extractor_tag = Some(lit_str.value());
2206 }
2207 }
2208 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2209 if let syn::Expr::Lit(syn::ExprLit {
2210 lit: syn::Lit::Str(lit_str),
2211 ..
2212 }) = nv.value
2213 {
2214 mode = Some(lit_str.value());
2215 }
2216 }
2217 _ => {}
2218 }
2219 }
2220 }
2221 }
2222
2223 let mode = mode.unwrap_or_else(|| "single".to_string());
2225
2226 if mode != "single" && mode != "multi_tag" {
2228 return syn::Error::new(
2229 input.ident.span(),
2230 "`mode` must be either \"single\" or \"multi_tag\"",
2231 )
2232 .to_compile_error()
2233 .into();
2234 }
2235
2236 let prompt_template = match prompt_template {
2238 Some(p) => p,
2239 None => {
2240 return syn::Error::new(
2241 input.ident.span(),
2242 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2243 )
2244 .to_compile_error()
2245 .into();
2246 }
2247 };
2248
2249 if mode == "multi_tag" {
2251 let enum_name = &input.ident;
2252 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2253 return generate_multi_tag_output(
2254 &input,
2255 enum_name,
2256 enum_data,
2257 prompt_template,
2258 actions_doc,
2259 );
2260 }
2261
2262 let extractor_tag = match extractor_tag {
2264 Some(t) => t,
2265 None => {
2266 return syn::Error::new(
2267 input.ident.span(),
2268 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2269 )
2270 .to_compile_error()
2271 .into();
2272 }
2273 };
2274
2275 let enum_name = &input.ident;
2277 let enum_docs = extract_doc_comments(&input.attrs);
2278
2279 let mut intents_doc_lines = Vec::new();
2280
2281 if !enum_docs.is_empty() {
2283 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2284 } else {
2285 intents_doc_lines.push(format!("{}:", enum_name));
2286 }
2287 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2289
2290 for variant in &enum_data.variants {
2292 let variant_name = &variant.ident;
2293 let variant_docs = extract_doc_comments(&variant.attrs);
2294
2295 if !variant_docs.is_empty() {
2296 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2297 } else {
2298 intents_doc_lines.push(format!("- {}", variant_name));
2299 }
2300 }
2301
2302 let intents_doc_str = intents_doc_lines.join("\n");
2303
2304 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2306 let user_variables: Vec<String> = placeholders
2307 .iter()
2308 .filter_map(|(name, _)| {
2309 if name != "intents_doc" {
2310 Some(name.clone())
2311 } else {
2312 None
2313 }
2314 })
2315 .collect();
2316
2317 let enum_name_str = enum_name.to_string();
2319 let snake_case_name = to_snake_case(&enum_name_str);
2320 let function_name = syn::Ident::new(
2321 &format!("build_{}_prompt", snake_case_name),
2322 proc_macro2::Span::call_site(),
2323 );
2324
2325 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2327 .iter()
2328 .map(|var| {
2329 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2330 quote! { #ident: &str }
2331 })
2332 .collect();
2333
2334 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2336 .iter()
2337 .map(|var| {
2338 let var_str = var.clone();
2339 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2340 quote! {
2341 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2342 }
2343 })
2344 .collect();
2345
2346 let converted_template = prompt_template.clone();
2348
2349 let extractor_name = syn::Ident::new(
2351 &format!("{}Extractor", enum_name),
2352 proc_macro2::Span::call_site(),
2353 );
2354
2355 let filtered_attrs: Vec<_> = input
2357 .attrs
2358 .iter()
2359 .filter(|attr| !attr.path().is_ident("intent"))
2360 .collect();
2361
2362 let vis = &input.vis;
2364 let generics = &input.generics;
2365 let variants = &enum_data.variants;
2366 let enum_output = quote! {
2367 #(#filtered_attrs)*
2368 #vis enum #enum_name #generics {
2369 #variants
2370 }
2371 };
2372
2373 let expanded = quote! {
2375 #enum_output
2377
2378 pub fn #function_name(#(#function_params),*) -> String {
2380 let mut env = minijinja::Environment::new();
2381 env.add_template("prompt", #converted_template)
2382 .expect("Failed to parse intent prompt template");
2383
2384 let tmpl = env.get_template("prompt").unwrap();
2385
2386 let mut __template_context = std::collections::HashMap::new();
2387
2388 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2390
2391 #(#context_insertions)*
2393
2394 tmpl.render(&__template_context)
2395 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2396 }
2397
2398 pub struct #extractor_name;
2400
2401 impl #extractor_name {
2402 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2403 }
2404
2405 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2406 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2407 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2409 }
2410 }
2411 };
2412
2413 TokenStream::from(expanded)
2414}
2415
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when the previous
/// character was not uppercase. Runs of capitals therefore stay together
/// (`HTTPServer` -> `httpserver`), which keeps previously generated names stable.
/// Uppercase characters are lowercased with their full Unicode mapping, which may be
/// more than one character (e.g. `İ` -> `i\u{307}`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` may yield several chars; keep all of them.
            result.extend(ch.to_lowercase());
            prev_upper = true;
        } else {
            result.push(ch);
            prev_upper = false;
        }
    }

    result
}
2436
/// Variant-level settings parsed from `#[action(...)]` in multi-tag intents.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// XML tag name from `tag = "..."`; `None` when the variant declares no tag.
    tag: Option<String>,
}
2442
2443fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2444 let mut result = ActionAttrs::default();
2445
2446 for attr in attrs {
2447 if attr.path().is_ident("action")
2448 && let Ok(meta_list) = attr.meta.require_list()
2449 && let Ok(metas) =
2450 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2451 {
2452 for meta in metas {
2453 if let Meta::NameValue(nv) = meta
2454 && nv.path.is_ident("tag")
2455 && let syn::Expr::Lit(syn::ExprLit {
2456 lit: syn::Lit::Str(lit_str),
2457 ..
2458 }) = nv.value
2459 {
2460 result.tag = Some(lit_str.value());
2461 }
2462 }
2463 }
2464 }
2465
2466 result
2467}
2468
/// Field-level role inside a multi-tag action variant.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// Set by `#[action(attribute)]`: the field is rendered as an XML attribute.
    is_attribute: bool,
    /// Set by `#[action(inner_text)]`: the field is the tag's text content.
    is_inner_text: bool,
}
2475
2476fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2477 let mut result = FieldActionAttrs::default();
2478
2479 for attr in attrs {
2480 if attr.path().is_ident("action")
2481 && let Ok(meta_list) = attr.meta.require_list()
2482 {
2483 let tokens_str = meta_list.tokens.to_string();
2484 if tokens_str == "attribute" {
2485 result.is_attribute = true;
2486 } else if tokens_str == "inner_text" {
2487 result.is_inner_text = true;
2488 }
2489 }
2490 }
2491
2492 result
2493}
2494
2495fn generate_multi_tag_actions_doc(
2497 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2498) -> String {
2499 let mut doc_lines = Vec::new();
2500
2501 for variant in variants {
2502 let action_attrs = parse_action_attrs(&variant.attrs);
2503
2504 if let Some(tag) = action_attrs.tag {
2505 let variant_docs = extract_doc_comments(&variant.attrs);
2506
2507 match &variant.fields {
2508 syn::Fields::Unit => {
2509 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2511 }
2512 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2513 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2515 }
2516 syn::Fields::Named(fields) => {
2517 let mut attrs_str = Vec::new();
2519 let mut has_inner_text = false;
2520
2521 for field in &fields.named {
2522 let field_name = field.ident.as_ref().unwrap();
2523 let field_attrs = parse_field_action_attrs(&field.attrs);
2524
2525 if field_attrs.is_attribute {
2526 attrs_str.push(format!("{}=\"...\"", field_name));
2527 } else if field_attrs.is_inner_text {
2528 has_inner_text = true;
2529 }
2530 }
2531
2532 let attrs_part = if !attrs_str.is_empty() {
2533 format!(" {}", attrs_str.join(" "))
2534 } else {
2535 String::new()
2536 };
2537
2538 if has_inner_text {
2539 doc_lines.push(format!(
2540 "- `<{}{}>...</{}>`: {}",
2541 tag, attrs_part, tag, variant_docs
2542 ));
2543 } else if !attrs_str.is_empty() {
2544 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2545 } else {
2546 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2547 }
2548
2549 for field in &fields.named {
2551 let field_name = field.ident.as_ref().unwrap();
2552 let field_attrs = parse_field_action_attrs(&field.attrs);
2553 let field_docs = extract_doc_comments(&field.attrs);
2554
2555 if field_attrs.is_attribute {
2556 doc_lines
2557 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2558 } else if field_attrs.is_inner_text {
2559 doc_lines
2560 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2561 }
2562 }
2563 }
2564 _ => {
2565 }
2567 }
2568 }
2569 }
2570
2571 doc_lines.join("\n")
2572}
2573
2574fn generate_tags_regex(
2576 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2577) -> String {
2578 let mut tag_names = Vec::new();
2579
2580 for variant in variants {
2581 let action_attrs = parse_action_attrs(&variant.attrs);
2582 if let Some(tag) = action_attrs.tag {
2583 tag_names.push(tag);
2584 }
2585 }
2586
2587 if tag_names.is_empty() {
2588 return String::new();
2589 }
2590
2591 let tags_pattern = tag_names.join("|");
2592 format!(
2595 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2596 tags_pattern, tags_pattern, tags_pattern
2597 )
2598}
2599
2600fn generate_multi_tag_output(
2602 input: &DeriveInput,
2603 enum_name: &syn::Ident,
2604 enum_data: &syn::DataEnum,
2605 prompt_template: String,
2606 actions_doc: String,
2607) -> TokenStream {
2608 let found_crate =
2609 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2610 let crate_path = match found_crate {
2611 FoundCrate::Itself => {
2612 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2614 quote!(::#ident)
2615 }
2616 FoundCrate::Name(name) => {
2617 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2618 quote!(::#ident)
2619 }
2620 };
2621
2622 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2624 let user_variables: Vec<String> = placeholders
2625 .iter()
2626 .filter_map(|(name, _)| {
2627 if name != "actions_doc" {
2628 Some(name.clone())
2629 } else {
2630 None
2631 }
2632 })
2633 .collect();
2634
2635 let enum_name_str = enum_name.to_string();
2637 let snake_case_name = to_snake_case(&enum_name_str);
2638 let function_name = syn::Ident::new(
2639 &format!("build_{}_prompt", snake_case_name),
2640 proc_macro2::Span::call_site(),
2641 );
2642
2643 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2645 .iter()
2646 .map(|var| {
2647 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2648 quote! { #ident: &str }
2649 })
2650 .collect();
2651
2652 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2654 .iter()
2655 .map(|var| {
2656 let var_str = var.clone();
2657 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2658 quote! {
2659 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2660 }
2661 })
2662 .collect();
2663
2664 let extractor_name = syn::Ident::new(
2666 &format!("{}Extractor", enum_name),
2667 proc_macro2::Span::call_site(),
2668 );
2669
2670 let filtered_attrs: Vec<_> = input
2672 .attrs
2673 .iter()
2674 .filter(|attr| !attr.path().is_ident("intent"))
2675 .collect();
2676
2677 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2679 .variants
2680 .iter()
2681 .map(|variant| {
2682 let variant_name = &variant.ident;
2683 let variant_attrs: Vec<_> = variant
2684 .attrs
2685 .iter()
2686 .filter(|attr| !attr.path().is_ident("action"))
2687 .collect();
2688 let fields = &variant.fields;
2689
2690 let filtered_fields = match fields {
2692 syn::Fields::Named(named_fields) => {
2693 let filtered: Vec<_> = named_fields
2694 .named
2695 .iter()
2696 .map(|field| {
2697 let field_name = &field.ident;
2698 let field_type = &field.ty;
2699 let field_vis = &field.vis;
2700 let filtered_attrs: Vec<_> = field
2701 .attrs
2702 .iter()
2703 .filter(|attr| !attr.path().is_ident("action"))
2704 .collect();
2705 quote! {
2706 #(#filtered_attrs)*
2707 #field_vis #field_name: #field_type
2708 }
2709 })
2710 .collect();
2711 quote! { { #(#filtered,)* } }
2712 }
2713 syn::Fields::Unnamed(unnamed_fields) => {
2714 let types: Vec<_> = unnamed_fields
2715 .unnamed
2716 .iter()
2717 .map(|field| {
2718 let field_type = &field.ty;
2719 quote! { #field_type }
2720 })
2721 .collect();
2722 quote! { (#(#types),*) }
2723 }
2724 syn::Fields::Unit => quote! {},
2725 };
2726
2727 quote! {
2728 #(#variant_attrs)*
2729 #variant_name #filtered_fields
2730 }
2731 })
2732 .collect();
2733
2734 let vis = &input.vis;
2735 let generics = &input.generics;
2736
2737 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2739
2740 let tags_regex = generate_tags_regex(&enum_data.variants);
2742
2743 let expanded = quote! {
2744 #(#filtered_attrs)*
2746 #vis enum #enum_name #generics {
2747 #(#filtered_variants),*
2748 }
2749
2750 pub fn #function_name(#(#function_params),*) -> String {
2752 let mut env = minijinja::Environment::new();
2753 env.add_template("prompt", #prompt_template)
2754 .expect("Failed to parse intent prompt template");
2755
2756 let tmpl = env.get_template("prompt").unwrap();
2757
2758 let mut __template_context = std::collections::HashMap::new();
2759
2760 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2762
2763 #(#context_insertions)*
2765
2766 tmpl.render(&__template_context)
2767 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2768 }
2769
2770 pub struct #extractor_name;
2772
2773 impl #extractor_name {
2774 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2775 use ::quick_xml::events::Event;
2776 use ::quick_xml::Reader;
2777
2778 let mut actions = Vec::new();
2779 let mut reader = Reader::from_str(text);
2780 reader.config_mut().trim_text(true);
2781
2782 let mut buf = Vec::new();
2783
2784 loop {
2785 match reader.read_event_into(&mut buf) {
2786 Ok(Event::Start(e)) => {
2787 let owned_e = e.into_owned();
2788 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2789 let is_empty = false;
2790
2791 #parsing_arms
2792 }
2793 Ok(Event::Empty(e)) => {
2794 let owned_e = e.into_owned();
2795 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2796 let is_empty = true;
2797
2798 #parsing_arms
2799 }
2800 Ok(Event::Eof) => break,
2801 Err(_) => {
2802 break;
2804 }
2805 _ => {}
2806 }
2807 buf.clear();
2808 }
2809
2810 actions.into_iter().next()
2811 }
2812
2813 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2814 use ::quick_xml::events::Event;
2815 use ::quick_xml::Reader;
2816
2817 let mut actions = Vec::new();
2818 let mut reader = Reader::from_str(text);
2819 reader.config_mut().trim_text(true);
2820
2821 let mut buf = Vec::new();
2822
2823 loop {
2824 match reader.read_event_into(&mut buf) {
2825 Ok(Event::Start(e)) => {
2826 let owned_e = e.into_owned();
2827 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2828 let is_empty = false;
2829
2830 #parsing_arms
2831 }
2832 Ok(Event::Empty(e)) => {
2833 let owned_e = e.into_owned();
2834 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2835 let is_empty = true;
2836
2837 #parsing_arms
2838 }
2839 Ok(Event::Eof) => break,
2840 Err(_) => {
2841 break;
2843 }
2844 _ => {}
2845 }
2846 buf.clear();
2847 }
2848
2849 Ok(actions)
2850 }
2851
2852 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2853 where
2854 F: FnMut(#enum_name) -> String,
2855 {
2856 use ::regex::Regex;
2857
2858 let regex_pattern = #tags_regex;
2859 if regex_pattern.is_empty() {
2860 return text.to_string();
2861 }
2862
2863 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2864 panic!("Failed to compile regex for action tags: {}", e);
2865 });
2866
2867 re.replace_all(text, |caps: &::regex::Captures| {
2868 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2869
2870 if let Some(action) = self.parse_single_action(matched) {
2872 transformer(action)
2873 } else {
2874 matched.to_string()
2876 }
2877 }).to_string()
2878 }
2879
2880 pub fn strip_actions(&self, text: &str) -> String {
2881 self.transform_actions(text, |_| String::new())
2882 }
2883 }
2884 };
2885
2886 TokenStream::from(expanded)
2887}
2888
2889fn generate_parsing_arms(
2891 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2892 enum_name: &syn::Ident,
2893) -> proc_macro2::TokenStream {
2894 let mut arms = Vec::new();
2895
2896 for variant in variants {
2897 let variant_name = &variant.ident;
2898 let action_attrs = parse_action_attrs(&variant.attrs);
2899
2900 if let Some(tag) = action_attrs.tag {
2901 match &variant.fields {
2902 syn::Fields::Unit => {
2903 arms.push(quote! {
2905 if &tag_name == #tag {
2906 actions.push(#enum_name::#variant_name);
2907 }
2908 });
2909 }
2910 syn::Fields::Unnamed(_fields) => {
2911 arms.push(quote! {
2913 if &tag_name == #tag && !is_empty {
2914 match reader.read_text(owned_e.name()) {
2916 Ok(text) => {
2917 actions.push(#enum_name::#variant_name(text.to_string()));
2918 }
2919 Err(_) => {
2920 actions.push(#enum_name::#variant_name(String::new()));
2922 }
2923 }
2924 }
2925 });
2926 }
2927 syn::Fields::Named(fields) => {
2928 let mut field_names = Vec::new();
2930 let mut has_inner_text_field = None;
2931
2932 for field in &fields.named {
2933 let field_name = field.ident.as_ref().unwrap();
2934 let field_attrs = parse_field_action_attrs(&field.attrs);
2935
2936 if field_attrs.is_attribute {
2937 field_names.push(field_name.clone());
2938 } else if field_attrs.is_inner_text {
2939 has_inner_text_field = Some(field_name.clone());
2940 }
2941 }
2942
2943 if let Some(inner_text_field) = has_inner_text_field {
2944 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2947 quote! {
2948 let mut #field_name = String::new();
2949 for attr in owned_e.attributes() {
2950 if let Ok(attr) = attr {
2951 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2952 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2953 break;
2954 }
2955 }
2956 }
2957 }
2958 }).collect();
2959
2960 arms.push(quote! {
2961 if &tag_name == #tag {
2962 #(#attr_extractions)*
2963
2964 if is_empty {
2966 let #inner_text_field = String::new();
2967 actions.push(#enum_name::#variant_name {
2968 #(#field_names,)*
2969 #inner_text_field,
2970 });
2971 } else {
2972 match reader.read_text(owned_e.name()) {
2974 Ok(text) => {
2975 let #inner_text_field = text.to_string();
2976 actions.push(#enum_name::#variant_name {
2977 #(#field_names,)*
2978 #inner_text_field,
2979 });
2980 }
2981 Err(_) => {
2982 let #inner_text_field = String::new();
2984 actions.push(#enum_name::#variant_name {
2985 #(#field_names,)*
2986 #inner_text_field,
2987 });
2988 }
2989 }
2990 }
2991 }
2992 });
2993 } else {
2994 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2996 quote! {
2997 let mut #field_name = String::new();
2998 for attr in owned_e.attributes() {
2999 if let Ok(attr) = attr {
3000 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3001 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3002 break;
3003 }
3004 }
3005 }
3006 }
3007 }).collect();
3008
3009 arms.push(quote! {
3010 if &tag_name == #tag {
3011 #(#attr_extractions)*
3012 actions.push(#enum_name::#variant_name {
3013 #(#field_names),*
3014 });
3015 }
3016 });
3017 }
3018 }
3019 }
3020 }
3021 }
3022
3023 quote! {
3024 #(#arms)*
3025 }
3026}
3027
3028#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3030pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3031 let input = parse_macro_input!(input as DeriveInput);
3032
3033 let found_crate =
3034 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3035 let crate_path = match found_crate {
3036 FoundCrate::Itself => {
3037 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3039 quote!(::#ident)
3040 }
3041 FoundCrate::Name(name) => {
3042 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3043 quote!(::#ident)
3044 }
3045 };
3046
3047 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
3049
3050 let struct_name = &input.ident;
3051 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3052
3053 let placeholders = parse_template_placeholders_with_mode(&template);
3055
3056 let mut converted_template = template.clone();
3058 let mut context_fields = Vec::new();
3059
3060 let fields = match &input.data {
3062 Data::Struct(data_struct) => match &data_struct.fields {
3063 syn::Fields::Named(fields) => &fields.named,
3064 _ => panic!("ToPromptFor is only supported for structs with named fields"),
3065 },
3066 _ => panic!("ToPromptFor is only supported for structs"),
3067 };
3068
3069 let has_mode_support = input.attrs.iter().any(|attr| {
3071 if attr.path().is_ident("prompt")
3072 && let Ok(metas) =
3073 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3074 {
3075 for meta in metas {
3076 if let Meta::NameValue(nv) = meta
3077 && nv.path.is_ident("mode")
3078 {
3079 return true;
3080 }
3081 }
3082 }
3083 false
3084 });
3085
3086 for (placeholder_name, mode_opt) in &placeholders {
3088 if placeholder_name == "self" {
3089 if let Some(specific_mode) = mode_opt {
3090 let unique_key = format!("self__{}", specific_mode);
3092
3093 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3095 let replacement = format!("{{{{ {} }}}}", unique_key);
3096 converted_template = converted_template.replace(&pattern, &replacement);
3097
3098 context_fields.push(quote! {
3100 context.insert(
3101 #unique_key.to_string(),
3102 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3103 );
3104 });
3105 } else {
3106 if has_mode_support {
3109 context_fields.push(quote! {
3111 context.insert(
3112 "self".to_string(),
3113 minijinja::Value::from(self.to_prompt_with_mode(mode))
3114 );
3115 });
3116 } else {
3117 context_fields.push(quote! {
3119 context.insert(
3120 "self".to_string(),
3121 minijinja::Value::from(self.to_prompt())
3122 );
3123 });
3124 }
3125 }
3126 } else {
3127 let field_exists = fields.iter().any(|f| {
3130 f.ident
3131 .as_ref()
3132 .is_some_and(|ident| ident == placeholder_name)
3133 });
3134
3135 if field_exists {
3136 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3137
3138 context_fields.push(quote! {
3142 context.insert(
3143 #placeholder_name.to_string(),
3144 minijinja::Value::from_serialize(&self.#field_ident)
3145 );
3146 });
3147 }
3148 }
3150 }
3151
3152 let expanded = quote! {
3153 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3154 where
3155 #target_type: serde::Serialize,
3156 {
3157 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3158 let mut env = minijinja::Environment::new();
3160 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3161 panic!("Failed to parse template: {}", e)
3162 });
3163
3164 let tmpl = env.get_template("prompt").unwrap();
3165
3166 let mut context = std::collections::HashMap::new();
3168 context.insert(
3170 "self".to_string(),
3171 minijinja::Value::from_serialize(self)
3172 );
3173 context.insert(
3175 "target".to_string(),
3176 minijinja::Value::from_serialize(target)
3177 );
3178 #(#context_fields)*
3179
3180 tmpl.render(context).unwrap_or_else(|e| {
3182 format!("Failed to render prompt: {}", e)
3183 })
3184 }
3185 }
3186 };
3187
3188 TokenStream::from(expanded)
3189}
3190
3191struct AgentAttrs {
3197 expertise: Option<String>,
3198 output: Option<syn::Type>,
3199 backend: Option<String>,
3200 model: Option<String>,
3201 inner: Option<String>,
3202 default_inner: Option<String>,
3203 max_retries: Option<u32>,
3204 profile: Option<String>,
3205}
3206
3207impl Parse for AgentAttrs {
3208 fn parse(input: ParseStream) -> syn::Result<Self> {
3209 let mut expertise = None;
3210 let mut output = None;
3211 let mut backend = None;
3212 let mut model = None;
3213 let mut inner = None;
3214 let mut default_inner = None;
3215 let mut max_retries = None;
3216 let mut profile = None;
3217
3218 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3219
3220 for meta in pairs {
3221 match meta {
3222 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3223 if let syn::Expr::Lit(syn::ExprLit {
3224 lit: syn::Lit::Str(lit_str),
3225 ..
3226 }) = &nv.value
3227 {
3228 expertise = Some(lit_str.value());
3229 }
3230 }
3231 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3232 if let syn::Expr::Lit(syn::ExprLit {
3233 lit: syn::Lit::Str(lit_str),
3234 ..
3235 }) = &nv.value
3236 {
3237 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3238 output = Some(ty);
3239 }
3240 }
3241 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3242 if let syn::Expr::Lit(syn::ExprLit {
3243 lit: syn::Lit::Str(lit_str),
3244 ..
3245 }) = &nv.value
3246 {
3247 backend = Some(lit_str.value());
3248 }
3249 }
3250 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3251 if let syn::Expr::Lit(syn::ExprLit {
3252 lit: syn::Lit::Str(lit_str),
3253 ..
3254 }) = &nv.value
3255 {
3256 model = Some(lit_str.value());
3257 }
3258 }
3259 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3260 if let syn::Expr::Lit(syn::ExprLit {
3261 lit: syn::Lit::Str(lit_str),
3262 ..
3263 }) = &nv.value
3264 {
3265 inner = Some(lit_str.value());
3266 }
3267 }
3268 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3269 if let syn::Expr::Lit(syn::ExprLit {
3270 lit: syn::Lit::Str(lit_str),
3271 ..
3272 }) = &nv.value
3273 {
3274 default_inner = Some(lit_str.value());
3275 }
3276 }
3277 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3278 if let syn::Expr::Lit(syn::ExprLit {
3279 lit: syn::Lit::Int(lit_int),
3280 ..
3281 }) = &nv.value
3282 {
3283 max_retries = Some(lit_int.base10_parse()?);
3284 }
3285 }
3286 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3287 if let syn::Expr::Lit(syn::ExprLit {
3288 lit: syn::Lit::Str(lit_str),
3289 ..
3290 }) = &nv.value
3291 {
3292 profile = Some(lit_str.value());
3293 }
3294 }
3295 _ => {}
3296 }
3297 }
3298
3299 Ok(AgentAttrs {
3300 expertise,
3301 output,
3302 backend,
3303 model,
3304 inner,
3305 default_inner,
3306 max_retries,
3307 profile,
3308 })
3309 }
3310}
3311
3312fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3314 for attr in attrs {
3315 if attr.path().is_ident("agent") {
3316 return attr.parse_args::<AgentAttrs>();
3317 }
3318 }
3319
3320 Ok(AgentAttrs {
3321 expertise: None,
3322 output: None,
3323 backend: None,
3324 model: None,
3325 inner: None,
3326 default_inner: None,
3327 max_retries: None,
3328 profile: None,
3329 })
3330}
3331
3332fn generate_backend_constructors(
3334 struct_name: &syn::Ident,
3335 backend: &str,
3336 _model: Option<&str>,
3337 _profile: Option<&str>,
3338 crate_path: &proc_macro2::TokenStream,
3339) -> proc_macro2::TokenStream {
3340 match backend {
3341 "claude" => {
3342 quote! {
3343 impl #struct_name {
3344 pub fn with_claude() -> Self {
3346 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3347 }
3348
3349 pub fn with_claude_model(model: &str) -> Self {
3351 Self::new(
3352 #crate_path::agent::impls::ClaudeCodeAgent::new()
3353 .with_model_str(model)
3354 )
3355 }
3356 }
3357 }
3358 }
3359 "gemini" => {
3360 quote! {
3361 impl #struct_name {
3362 pub fn with_gemini() -> Self {
3364 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3365 }
3366
3367 pub fn with_gemini_model(model: &str) -> Self {
3369 Self::new(
3370 #crate_path::agent::impls::GeminiAgent::new()
3371 .with_model_str(model)
3372 )
3373 }
3374 }
3375 }
3376 }
3377 _ => quote! {},
3378 }
3379}
3380
3381fn generate_default_impl(
3383 struct_name: &syn::Ident,
3384 backend: &str,
3385 model: Option<&str>,
3386 profile: Option<&str>,
3387 crate_path: &proc_macro2::TokenStream,
3388) -> proc_macro2::TokenStream {
3389 let profile_expr = if let Some(profile_str) = profile {
3391 match profile_str.to_lowercase().as_str() {
3392 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3393 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3394 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3395 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3397 } else {
3398 quote! { #crate_path::agent::ExecutionProfile::default() }
3399 };
3400
3401 let agent_init = match backend {
3402 "gemini" => {
3403 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3404
3405 if let Some(model_str) = model {
3406 builder = quote! { #builder.with_model_str(#model_str) };
3407 }
3408
3409 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3410 builder
3411 }
3412 _ => {
3413 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3415
3416 if let Some(model_str) = model {
3417 builder = quote! { #builder.with_model_str(#model_str) };
3418 }
3419
3420 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3421 builder
3422 }
3423 };
3424
3425 quote! {
3426 impl Default for #struct_name {
3427 fn default() -> Self {
3428 Self::new(#agent_init)
3429 }
3430 }
3431 }
3432}
3433
3434#[proc_macro_derive(Agent, attributes(agent))]
3443pub fn derive_agent(input: TokenStream) -> TokenStream {
3444 let input = parse_macro_input!(input as DeriveInput);
3445 let struct_name = &input.ident;
3446
3447 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3449 Ok(attrs) => attrs,
3450 Err(e) => return e.to_compile_error().into(),
3451 };
3452
3453 let expertise = agent_attrs
3454 .expertise
3455 .unwrap_or_else(|| String::from("general AI assistant"));
3456 let output_type = agent_attrs
3457 .output
3458 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3459 let backend = agent_attrs
3460 .backend
3461 .unwrap_or_else(|| String::from("claude"));
3462 let model = agent_attrs.model;
3463 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3468 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3469 let crate_path = match found_crate {
3470 FoundCrate::Itself => {
3471 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3473 quote!(::#ident)
3474 }
3475 FoundCrate::Name(name) => {
3476 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3477 quote!(::#ident)
3478 }
3479 };
3480
3481 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3482
3483 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3485 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3486
3487 let enhanced_expertise = if is_string_output {
3489 quote! { #expertise }
3491 } else {
3492 let type_name = quote!(#output_type).to_string();
3494 quote! {
3495 {
3496 use std::sync::OnceLock;
3497 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3498
3499 EXPERTISE_CACHE.get_or_init(|| {
3500 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3502
3503 if schema.is_empty() {
3504 format!(
3506 concat!(
3507 #expertise,
3508 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3509 "Do not include any text outside the JSON object."
3510 ),
3511 #type_name
3512 )
3513 } else {
3514 format!(
3516 concat!(
3517 #expertise,
3518 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3519 ),
3520 schema
3521 )
3522 }
3523 }).as_str()
3524 }
3525 }
3526 };
3527
3528 let agent_init = match backend.as_str() {
3530 "gemini" => {
3531 if let Some(model_str) = model {
3532 quote! {
3533 use #crate_path::agent::impls::GeminiAgent;
3534 let agent = GeminiAgent::new().with_model_str(#model_str);
3535 }
3536 } else {
3537 quote! {
3538 use #crate_path::agent::impls::GeminiAgent;
3539 let agent = GeminiAgent::new();
3540 }
3541 }
3542 }
3543 "claude" => {
3544 if let Some(model_str) = model {
3545 quote! {
3546 use #crate_path::agent::impls::ClaudeCodeAgent;
3547 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3548 }
3549 } else {
3550 quote! {
3551 use #crate_path::agent::impls::ClaudeCodeAgent;
3552 let agent = ClaudeCodeAgent::new();
3553 }
3554 }
3555 }
3556 _ => {
3557 if let Some(model_str) = model {
3559 quote! {
3560 use #crate_path::agent::impls::ClaudeCodeAgent;
3561 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3562 }
3563 } else {
3564 quote! {
3565 use #crate_path::agent::impls::ClaudeCodeAgent;
3566 let agent = ClaudeCodeAgent::new();
3567 }
3568 }
3569 }
3570 };
3571
3572 let expanded = quote! {
3573 #[async_trait::async_trait]
3574 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3575 type Output = #output_type;
3576
3577 fn expertise(&self) -> &str {
3578 #enhanced_expertise
3579 }
3580
3581 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3582 #agent_init
3584
3585 let agent_ref = &agent;
3587 #crate_path::agent::retry::retry_execution(
3588 #max_retries,
3589 &intent,
3590 move |payload| {
3591 let payload = payload.clone();
3592 async move {
3593 let response = agent_ref.execute(payload).await?;
3595
3596 let json_str = #crate_path::extract_json(&response)
3598 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3599 message: format!("Failed to extract JSON: {}", e),
3600 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3601 })?;
3602
3603 serde_json::from_str::<Self::Output>(&json_str)
3605 .map_err(|e| {
3606 let reason = if e.is_eof() {
3608 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3609 } else if e.is_syntax() {
3610 #crate_path::agent::error::ParseErrorReason::InvalidJson
3611 } else {
3612 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3613 };
3614
3615 #crate_path::agent::AgentError::ParseError {
3616 message: format!("Failed to parse JSON: {}", e),
3617 reason,
3618 }
3619 })
3620 }
3621 }
3622 ).await
3623 }
3624
3625 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3626 #agent_init
3628 agent.is_available().await
3629 }
3630 }
3631 };
3632
3633 TokenStream::from(expanded)
3634}
3635
3636#[proc_macro_attribute]
3651pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3652 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3654 Ok(attrs) => attrs,
3655 Err(e) => return e.to_compile_error().into(),
3656 };
3657
3658 let input = parse_macro_input!(item as DeriveInput);
3660 let struct_name = &input.ident;
3661 let vis = &input.vis;
3662
3663 let expertise = agent_attrs
3664 .expertise
3665 .unwrap_or_else(|| String::from("general AI assistant"));
3666 let output_type = agent_attrs
3667 .output
3668 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3669 let backend = agent_attrs
3670 .backend
3671 .unwrap_or_else(|| String::from("claude"));
3672 let model = agent_attrs.model;
3673 let profile = agent_attrs.profile;
3674
3675 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3677 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3678
3679 let found_crate =
3681 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3682 let crate_path = match found_crate {
3683 FoundCrate::Itself => {
3684 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3685 quote!(::#ident)
3686 }
3687 FoundCrate::Name(name) => {
3688 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3689 quote!(::#ident)
3690 }
3691 };
3692
3693 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3695 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3696
3697 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3699 let type_path: syn::Type =
3701 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3702 quote! { #type_path }
3703 } else {
3704 match backend.as_str() {
3706 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3707 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3708 }
3709 };
3710
3711 let struct_def = quote! {
3713 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3714 inner: #inner_generic_ident,
3715 }
3716 };
3717
3718 let constructors = quote! {
3720 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3721 pub fn new(inner: #inner_generic_ident) -> Self {
3723 Self { inner }
3724 }
3725 }
3726 };
3727
3728 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3730 let default_impl = quote! {
3732 impl Default for #struct_name {
3733 fn default() -> Self {
3734 Self {
3735 inner: <#default_agent_type as Default>::default(),
3736 }
3737 }
3738 }
3739 };
3740 (quote! {}, default_impl)
3741 } else {
3742 let backend_constructors = generate_backend_constructors(
3744 struct_name,
3745 &backend,
3746 model.as_deref(),
3747 profile.as_deref(),
3748 &crate_path,
3749 );
3750 let default_impl = generate_default_impl(
3751 struct_name,
3752 &backend,
3753 model.as_deref(),
3754 profile.as_deref(),
3755 &crate_path,
3756 );
3757 (backend_constructors, default_impl)
3758 };
3759
3760 let enhanced_expertise = if is_string_output {
3762 quote! { #expertise }
3764 } else {
3765 let type_name = quote!(#output_type).to_string();
3767 quote! {
3768 {
3769 use std::sync::OnceLock;
3770 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3771
3772 EXPERTISE_CACHE.get_or_init(|| {
3773 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3775
3776 if schema.is_empty() {
3777 format!(
3779 concat!(
3780 #expertise,
3781 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3782 "Do not include any text outside the JSON object."
3783 ),
3784 #type_name
3785 )
3786 } else {
3787 format!(
3789 concat!(
3790 #expertise,
3791 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3792 ),
3793 schema
3794 )
3795 }
3796 }).as_str()
3797 }
3798 }
3799 };
3800
3801 let agent_impl = quote! {
3803 #[async_trait::async_trait]
3804 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3805 where
3806 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3807 {
3808 type Output = #output_type;
3809
3810 fn expertise(&self) -> &str {
3811 #enhanced_expertise
3812 }
3813
3814 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3815 let enhanced_payload = intent.prepend_text(self.expertise());
3817
3818 let response = self.inner.execute(enhanced_payload).await?;
3820
3821 let json_str = #crate_path::extract_json(&response)
3823 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3824 message: e.to_string(),
3825 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3826 })?;
3827
3828 serde_json::from_str(&json_str).map_err(|e| {
3830 let reason = if e.is_eof() {
3831 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3832 } else if e.is_syntax() {
3833 #crate_path::agent::error::ParseErrorReason::InvalidJson
3834 } else {
3835 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3836 };
3837 #crate_path::agent::AgentError::ParseError {
3838 message: e.to_string(),
3839 reason,
3840 }
3841 })
3842 }
3843
3844 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3845 self.inner.is_available().await
3846 }
3847 }
3848 };
3849
3850 let expanded = quote! {
3851 #struct_def
3852 #constructors
3853 #backend_constructors
3854 #default_impl
3855 #agent_impl
3856 };
3857
3858 TokenStream::from(expanded)
3859}
3860
3861#[proc_macro_derive(TypeMarker)]
3883pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3884 let input = parse_macro_input!(input as DeriveInput);
3885 let struct_name = &input.ident;
3886 let type_name_str = struct_name.to_string();
3887
3888 let found_crate =
3890 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3891 let crate_path = match found_crate {
3892 FoundCrate::Itself => {
3893 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3894 quote!(::#ident)
3895 }
3896 FoundCrate::Name(name) => {
3897 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3898 quote!(::#ident)
3899 }
3900 };
3901
3902 let expanded = quote! {
3903 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3904 const TYPE_NAME: &'static str = #type_name_str;
3905 }
3906 };
3907
3908 TokenStream::from(expanded)
3909}
3910
3911#[proc_macro_attribute]
3947pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
3948 let input = parse_macro_input!(item as syn::DeriveInput);
3949 let struct_name = &input.ident;
3950 let vis = &input.vis;
3951 let type_name_str = struct_name.to_string();
3952
3953 let default_fn_name = syn::Ident::new(
3955 &format!("default_{}_type", to_snake_case(&type_name_str)),
3956 struct_name.span(),
3957 );
3958
3959 let found_crate =
3961 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3962 let crate_path = match found_crate {
3963 FoundCrate::Itself => {
3964 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3965 quote!(::#ident)
3966 }
3967 FoundCrate::Name(name) => {
3968 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3969 quote!(::#ident)
3970 }
3971 };
3972
3973 let fields = match &input.data {
3975 syn::Data::Struct(data_struct) => match &data_struct.fields {
3976 syn::Fields::Named(fields) => &fields.named,
3977 _ => {
3978 return syn::Error::new_spanned(
3979 struct_name,
3980 "type_marker only works with structs with named fields",
3981 )
3982 .to_compile_error()
3983 .into();
3984 }
3985 },
3986 _ => {
3987 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
3988 .to_compile_error()
3989 .into();
3990 }
3991 };
3992
3993 let mut new_fields = vec![];
3995
3996 let default_fn_name_str = default_fn_name.to_string();
3998 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
3999
4000 new_fields.push(quote! {
4005 #[serde(default = #default_fn_name_lit)]
4006 __type: String
4007 });
4008
4009 for field in fields {
4011 new_fields.push(quote! { #field });
4012 }
4013
4014 let attrs = &input.attrs;
4016 let generics = &input.generics;
4017
4018 let expanded = quote! {
4019 fn #default_fn_name() -> String {
4021 #type_name_str.to_string()
4022 }
4023
4024 #(#attrs)*
4026 #vis struct #struct_name #generics {
4027 #(#new_fields),*
4028 }
4029
4030 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4032 const TYPE_NAME: &'static str = #type_name_str;
4033 }
4034 };
4035
4036 TokenStream::from(expanded)
4037}