1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, vec_inner_type) = extract_vec_inner_type(&field.ty);
169 let (is_option, option_inner_type) = extract_option_inner_type(&field.ty);
170 let (is_map, map_value_type) = extract_map_value_type(&field.ty);
171 let (is_set, set_element_type) = extract_set_element_type(&field.ty);
172
173 if is_vec {
174 let comment = if !field_docs.is_empty() {
177 format!(" // {}", field_docs)
178 } else {
179 String::new()
180 };
181
182 field_schema_parts.push(quote! {
183 {
184 let type_name = stringify!(#vec_inner_type);
185 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
186 }
187 });
188
189 if let Some(inner) = vec_inner_type
191 && !is_primitive_type(inner)
192 {
193 nested_type_collectors.push(quote! {
194 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
195 });
196 }
197 } else if is_option {
198 let comment = if !field_docs.is_empty() {
201 format!(" // {}", field_docs)
202 } else {
203 String::new()
204 };
205
206 if let Some(inner) = option_inner_type {
207 let (inner_is_map, inner_map_value) = extract_map_value_type(inner);
209 let (inner_is_set, inner_set_element) = extract_set_element_type(inner);
210 let (inner_is_vec, inner_vec_element) = extract_vec_inner_type(inner);
211
212 if inner_is_map {
213 if let Some(value_type) = inner_map_value {
215 let is_value_primitive = is_primitive_type(value_type);
216 if !is_value_primitive {
217 field_schema_parts.push(quote! {
218 {
219 let type_name = stringify!(#value_type);
220 format!(" {}: Record<string, {}> | null;{}", #field_name_str, type_name, #comment)
221 }
222 });
223 nested_type_collectors.push(quote! {
224 <#value_type as #crate_path::prompt::ToPrompt>::prompt_schema()
225 });
226 } else {
227 let type_str = format_type_for_schema(value_type);
228 field_schema_parts.push(quote! {
229 format!(" {}: Record<string, {}> | null;{}", #field_name_str, #type_str, #comment)
230 });
231 }
232 }
233 } else if inner_is_set {
234 if let Some(element_type) = inner_set_element {
236 let is_element_primitive = is_primitive_type(element_type);
237 if !is_element_primitive {
238 field_schema_parts.push(quote! {
239 {
240 let type_name = stringify!(#element_type);
241 format!(" {}: {}[] | null;{}", #field_name_str, type_name, #comment)
242 }
243 });
244 nested_type_collectors.push(quote! {
245 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
246 });
247 } else {
248 let type_str = format_type_for_schema(element_type);
249 field_schema_parts.push(quote! {
250 format!(" {}: {}[] | null;{}", #field_name_str, #type_str, #comment)
251 });
252 }
253 }
254 } else if inner_is_vec {
255 if let Some(element_type) = inner_vec_element {
257 let is_element_primitive = is_primitive_type(element_type);
258 if !is_element_primitive {
259 field_schema_parts.push(quote! {
260 {
261 let type_name = stringify!(#element_type);
262 format!(" {}: {}[] | null;{}", #field_name_str, type_name, #comment)
263 }
264 });
265 nested_type_collectors.push(quote! {
266 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
267 });
268 } else {
269 let type_str = format_type_for_schema(element_type);
270 field_schema_parts.push(quote! {
271 format!(" {}: {}[] | null;{}", #field_name_str, #type_str, #comment)
272 });
273 }
274 }
275 } else {
276 let is_inner_primitive = is_primitive_type(inner);
278
279 if !is_inner_primitive {
280 field_schema_parts.push(quote! {
282 {
283 let type_name = stringify!(#inner);
284 format!(" {}: {} | null;{}", #field_name_str, type_name, #comment)
285 }
286 });
287
288 nested_type_collectors.push(quote! {
290 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
291 });
292 } else {
293 let type_str = format_type_for_schema(inner);
295 field_schema_parts.push(quote! {
296 format!(" {}: {} | null;{}", #field_name_str, #type_str, #comment)
297 });
298 }
299 }
300 }
301 } else if is_map {
302 let comment = if !field_docs.is_empty() {
305 format!(" // {}", field_docs)
306 } else {
307 String::new()
308 };
309
310 if let Some(value_type) = map_value_type {
311 let is_value_primitive = is_primitive_type(value_type);
312
313 if !is_value_primitive {
314 field_schema_parts.push(quote! {
316 {
317 let type_name = stringify!(#value_type);
318 format!(" {}: Record<string, {}>;{}", #field_name_str, type_name, #comment)
319 }
320 });
321
322 nested_type_collectors.push(quote! {
324 <#value_type as #crate_path::prompt::ToPrompt>::prompt_schema()
325 });
326 } else {
327 let type_str = format_type_for_schema(value_type);
329 field_schema_parts.push(quote! {
330 format!(" {}: Record<string, {}>;{}", #field_name_str, #type_str, #comment)
331 });
332 }
333 }
334 } else if is_set {
335 let comment = if !field_docs.is_empty() {
338 format!(" // {}", field_docs)
339 } else {
340 String::new()
341 };
342
343 if let Some(element_type) = set_element_type {
344 let is_element_primitive = is_primitive_type(element_type);
345
346 if !is_element_primitive {
347 field_schema_parts.push(quote! {
349 {
350 let type_name = stringify!(#element_type);
351 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
352 }
353 });
354
355 nested_type_collectors.push(quote! {
357 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
358 });
359 } else {
360 let type_str = format_type_for_schema(element_type);
362 field_schema_parts.push(quote! {
363 format!(" {}: {}[];{}", #field_name_str, #type_str, #comment)
364 });
365 }
366 }
367 } else {
368 let field_type = &field.ty;
370 let is_primitive = is_primitive_type(field_type);
371
372 if !is_primitive {
373 let comment = if !field_docs.is_empty() {
376 format!(" // {}", field_docs)
377 } else {
378 String::new()
379 };
380
381 field_schema_parts.push(quote! {
382 {
383 let type_name = stringify!(#field_type);
384 format!(" {}: {};{}", #field_name_str, type_name, #comment)
385 }
386 });
387
388 nested_type_collectors.push(quote! {
390 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
391 });
392 } else {
393 let type_str = format_type_for_schema(&field.ty);
396 let comment = if !field_docs.is_empty() {
397 format!(" // {}", field_docs)
398 } else {
399 String::new()
400 };
401
402 field_schema_parts.push(quote! {
403 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
404 });
405 }
406 }
407 }
408
409 let mut header_lines = Vec::new();
424
425 if !struct_docs.is_empty() {
427 header_lines.push("/**".to_string());
428 header_lines.push(format!(" * {}", struct_docs));
429 header_lines.push(" */".to_string());
430 }
431
432 header_lines.push(format!("type {} = {{", struct_name));
434
435 quote! {
436 {
437 let mut all_lines: Vec<String> = Vec::new();
438
439 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
441 let mut seen_types = std::collections::HashSet::<String>::new();
442
443 for schema in nested_schemas {
444 if !schema.is_empty() {
445 if seen_types.insert(schema.clone()) {
447 all_lines.push(schema);
448 all_lines.push(String::new()); }
450 }
451 }
452
453 let mut lines: Vec<String> = Vec::new();
455 #(lines.push(#header_lines.to_string());)*
456 #(lines.push(#field_schema_parts);)*
457 lines.push("}".to_string());
458 all_lines.push(lines.join("\n"));
459
460 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
461 }
462 }
463}
464
465fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
467 if let syn::Type::Path(type_path) = ty
468 && let Some(last_segment) = type_path.path.segments.last()
469 && last_segment.ident == "Vec"
470 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
471 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
472 {
473 return (true, Some(inner_type));
474 }
475 (false, None)
476}
477
478fn extract_option_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
480 if let syn::Type::Path(type_path) = ty
481 && let Some(last_segment) = type_path.path.segments.last()
482 && last_segment.ident == "Option"
483 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
484 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
485 {
486 return (true, Some(inner_type));
487 }
488 (false, None)
489}
490
491fn extract_map_value_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
493 if let syn::Type::Path(type_path) = ty
494 && let Some(last_segment) = type_path.path.segments.last()
495 && (last_segment.ident == "HashMap" || last_segment.ident == "BTreeMap")
496 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
497 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
498 {
499 return (true, Some(value_type));
500 }
501 (false, None)
502}
503
504fn extract_set_element_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
506 if let syn::Type::Path(type_path) = ty
507 && let Some(last_segment) = type_path.path.segments.last()
508 && (last_segment.ident == "HashSet" || last_segment.ident == "BTreeSet")
509 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
510 && let Some(syn::GenericArgument::Type(element_type)) = args.args.first()
511 {
512 return (true, Some(element_type));
513 }
514 (false, None)
515}
516
517fn extract_expandable_type(ty: &syn::Type) -> &syn::Type {
520 if let syn::Type::Path(type_path) = ty
521 && let Some(last_segment) = type_path.path.segments.last()
522 {
523 let type_name = last_segment.ident.to_string();
524
525 if type_name == "Option"
527 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
528 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
529 {
530 return extract_expandable_type(inner_type);
531 }
532
533 if type_name == "Vec"
535 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
536 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
537 {
538 return extract_expandable_type(inner_type);
539 }
540
541 if (type_name == "HashSet" || type_name == "BTreeSet")
543 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
544 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
545 {
546 return extract_expandable_type(inner_type);
547 }
548
549 if (type_name == "HashMap" || type_name == "BTreeMap")
552 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
553 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
554 {
555 return extract_expandable_type(value_type);
556 }
557 }
558
559 ty
561}
562
563fn is_primitive_type(ty: &syn::Type) -> bool {
565 if let syn::Type::Path(type_path) = ty
566 && let Some(last_segment) = type_path.path.segments.last()
567 {
568 let type_name = last_segment.ident.to_string();
569
570 if type_name == "Option"
572 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
573 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
574 {
575 return is_primitive_type(inner_type);
576 }
577
578 if type_name == "Vec"
580 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
581 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
582 {
583 return is_primitive_type(inner_type);
584 }
585
586 if (type_name == "HashSet" || type_name == "BTreeSet")
588 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
589 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
590 {
591 return is_primitive_type(inner_type);
592 }
593
594 if (type_name == "HashMap" || type_name == "BTreeMap")
597 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
598 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
599 {
600 return is_primitive_type(value_type);
601 }
602
603 matches!(
604 type_name.as_str(),
605 "String"
606 | "str"
607 | "i8"
608 | "i16"
609 | "i32"
610 | "i64"
611 | "i128"
612 | "isize"
613 | "u8"
614 | "u16"
615 | "u32"
616 | "u64"
617 | "u128"
618 | "usize"
619 | "f32"
620 | "f64"
621 | "bool"
622 )
623 } else {
624 true
626 }
627}
628
629fn format_type_for_schema(ty: &syn::Type) -> String {
631 match ty {
633 syn::Type::Path(type_path) => {
634 let path = &type_path.path;
635 if let Some(last_segment) = path.segments.last() {
636 let type_name = last_segment.ident.to_string();
637
638 if type_name == "Option"
640 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
641 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
642 {
643 return format!("{} | null", format_type_for_schema(inner_type));
644 }
645
646 match type_name.as_str() {
648 "String" | "str" => "string".to_string(),
649 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
650 | "u64" | "u128" | "usize" => "number".to_string(),
651 "f32" | "f64" => "number".to_string(),
652 "bool" => "boolean".to_string(),
653 "Vec" => {
654 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
655 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
656 {
657 return format!("{}[]", format_type_for_schema(inner_type));
658 }
659 "array".to_string()
660 }
661 _ => type_name,
663 }
664 } else {
665 "unknown".to_string()
666 }
667 }
668 _ => "unknown".to_string(),
669 }
670}
671
/// Settings read from `#[prompt(...)]` attributes by `parse_prompt_attributes`.
///
/// The default value has `skip == false` and no rename or description.
#[derive(Default)]
struct PromptAttributes {
    /// Set by `#[prompt(skip)]`.
    skip: bool,
    /// Set by `#[prompt(rename = "...")]`.
    rename: Option<String>,
    /// Set by `#[prompt(description = "...")]` or `#[prompt("...")]`.
    description: Option<String>,
}
679
680fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
683 let mut result = PromptAttributes::default();
684
685 for attr in attrs {
686 if attr.path().is_ident("prompt") {
687 if let Ok(meta_list) = attr.meta.require_list() {
689 if let Ok(metas) =
691 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
692 {
693 for meta in metas {
694 if let Meta::NameValue(nv) = meta {
695 if nv.path.is_ident("rename") {
696 if let syn::Expr::Lit(syn::ExprLit {
697 lit: syn::Lit::Str(lit_str),
698 ..
699 }) = nv.value
700 {
701 result.rename = Some(lit_str.value());
702 }
703 } else if nv.path.is_ident("description")
704 && let syn::Expr::Lit(syn::ExprLit {
705 lit: syn::Lit::Str(lit_str),
706 ..
707 }) = nv.value
708 {
709 result.description = Some(lit_str.value());
710 }
711 } else if let Meta::Path(path) = meta
712 && path.is_ident("skip")
713 {
714 result.skip = true;
715 }
716 }
717 }
718
719 let tokens_str = meta_list.tokens.to_string();
721 if tokens_str == "skip" {
722 result.skip = true;
723 }
724 }
725
726 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
728 result.description = Some(lit_str.value());
729 }
730 }
731 }
732 result
733}
734
/// Produces a sample JSON-ish value for a schema type string.
///
/// * `string` -> `"example"`, `number` -> `0`, `boolean` -> `false`
/// * anything ending in `[]` -> `[]`
/// * a union (`A | B`) -> the example for its first member
/// * anything else -> `"<See TypeName>"`
fn generate_example_value_for_type(type_str: &str) -> String {
    let builtin = match type_str {
        "string" => Some("\"example\""),
        "number" => Some("0"),
        "boolean" => Some("false"),
        _ => None,
    };
    if let Some(example) = builtin {
        return example.to_string();
    }

    // Arrays are checked before unions, so `A | B[]` is treated as an array.
    if type_str.ends_with("[]") {
        return "[]".to_string();
    }

    if let Some((first_member, _)) = type_str.split_once('|') {
        return generate_example_value_for_type(first_member.trim());
    }

    format!("\"<See {}>\"", type_str)
}
755
756fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
758 for attr in attrs {
759 if attr.path().is_ident("serde")
760 && let Ok(meta_list) = attr.meta.require_list()
761 && let Ok(metas) =
762 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
763 {
764 for meta in metas {
765 if let Meta::NameValue(nv) = meta
766 && nv.path.is_ident("rename")
767 && let syn::Expr::Lit(syn::ExprLit {
768 lit: syn::Lit::Str(lit_str),
769 ..
770 }) = nv.value
771 {
772 return Some(lit_str.value());
773 }
774 }
775 }
776 }
777 None
778}
779
/// Case-conversion rules selectable through `#[serde(rename_all = "...")]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    #[allow(dead_code)]
    None,
    LowerCase,
    UpperCase,
    PascalCase,
    CamelCase,
    SnakeCase,
    ScreamingSnakeCase,
    KebabCase,
    ScreamingKebabCase,
}

impl RenameRule {
    /// Maps a `rename_all` value (e.g. `"snake_case"`) to its rule.
    ///
    /// Returns `None` for any string that is not one of the eight known names.
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies the rule to `name`, which is treated as already being in
    /// PascalCase (so `None` and `PascalCase` return it unchanged).
    ///
    /// Word boundaries are placed before every uppercase character except the
    /// first character of the name.
    fn apply(&self, name: &str) -> String {
        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_lowercase(),
            Self::UpperCase => name.to_uppercase(),
            Self::CamelCase => {
                let mut rest = name.chars();
                match rest.next() {
                    None => String::new(),
                    Some(first) => first.to_lowercase().chain(rest).collect(),
                }
            }
            Self::SnakeCase => Self::join_words(name, '_', false),
            Self::ScreamingSnakeCase => Self::join_words(name, '_', true),
            Self::KebabCase => Self::join_words(name, '-', false),
            Self::ScreamingKebabCase => Self::join_words(name, '-', true),
        }
    }

    /// Inserts `separator` before each non-leading uppercase character and
    /// converts every character to upper or lower case.
    ///
    /// The full case mapping of each character is kept: a character whose
    /// mapping expands to several characters (e.g. `ß` -> `SS`) is not
    /// truncated, which keeps the result consistent with `str::to_uppercase` /
    /// `str::to_lowercase` as used by the `UpperCase` / `LowerCase` rules.
    fn join_words(name: &str, separator: char, upper: bool) -> String {
        let mut out = String::with_capacity(name.len());
        for (index, ch) in name.chars().enumerate() {
            if index > 0 && ch.is_uppercase() {
                out.push(separator);
            }
            if upper {
                out.extend(ch.to_uppercase());
            } else {
                out.extend(ch.to_lowercase());
            }
        }
        out
    }
}
873
874fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
876 for attr in attrs {
877 if attr.path().is_ident("serde")
878 && let Ok(meta_list) = attr.meta.require_list()
879 {
880 if let Ok(metas) =
882 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
883 {
884 for meta in metas {
885 if let Meta::NameValue(nv) = meta
886 && nv.path.is_ident("rename_all")
887 && let syn::Expr::Lit(syn::ExprLit {
888 lit: syn::Lit::Str(lit_str),
889 ..
890 }) = nv.value
891 {
892 return RenameRule::from_str(&lit_str.value());
893 }
894 }
895 }
896 }
897 }
898 None
899}
900
901fn parse_serde_tag(attrs: &[syn::Attribute]) -> Option<String> {
904 for attr in attrs {
905 if attr.path().is_ident("serde")
906 && let Ok(meta_list) = attr.meta.require_list()
907 {
908 if let Ok(metas) =
910 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
911 {
912 for meta in metas {
913 if let Meta::NameValue(nv) = meta
914 && nv.path.is_ident("tag")
915 && let syn::Expr::Lit(syn::ExprLit {
916 lit: syn::Lit::Str(lit_str),
917 ..
918 }) = nv.value
919 {
920 return Some(lit_str.value());
921 }
922 }
923 }
924 }
925 }
926 None
927}
928
929fn parse_serde_untagged(attrs: &[syn::Attribute]) -> bool {
932 for attr in attrs {
933 if attr.path().is_ident("serde")
934 && let Ok(meta_list) = attr.meta.require_list()
935 {
936 if let Ok(metas) =
938 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
939 {
940 for meta in metas {
941 if let Meta::Path(path) = meta
942 && path.is_ident("untagged")
943 {
944 return true;
945 }
946 }
947 }
948 }
949 }
950 false
951}
952
/// Field-level settings read from `#[prompt(...)]` by
/// `parse_field_prompt_attrs`.
///
/// The default value has both flags `false` and every option `None`.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by `#[prompt(skip)]`.
    skip: bool,
    /// Set by `#[prompt(rename = "...")]`.
    rename: Option<String>,
    /// Set by `#[prompt(format_with = "...")]`.
    format_with: Option<String>,
    /// Set by `#[prompt(image)]`.
    image: bool,
    /// Set by `#[prompt(example = "...")]`.
    example: Option<String>,
}
962
963fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
965 let mut result = FieldPromptAttrs::default();
966
967 for attr in attrs {
968 if attr.path().is_ident("prompt") {
969 if let Ok(meta_list) = attr.meta.require_list() {
971 if let Ok(metas) =
973 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
974 {
975 for meta in metas {
976 match meta {
977 Meta::Path(path) if path.is_ident("skip") => {
978 result.skip = true;
979 }
980 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
981 if let syn::Expr::Lit(syn::ExprLit {
982 lit: syn::Lit::Str(lit_str),
983 ..
984 }) = nv.value
985 {
986 result.rename = Some(lit_str.value());
987 }
988 }
989 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
990 if let syn::Expr::Lit(syn::ExprLit {
991 lit: syn::Lit::Str(lit_str),
992 ..
993 }) = nv.value
994 {
995 result.format_with = Some(lit_str.value());
996 }
997 }
998 Meta::Path(path) if path.is_ident("image") => {
999 result.image = true;
1000 }
1001 Meta::NameValue(nv) if nv.path.is_ident("example") => {
1002 if let syn::Expr::Lit(syn::ExprLit {
1003 lit: syn::Lit::Str(lit_str),
1004 ..
1005 }) = nv.value
1006 {
1007 result.example = Some(lit_str.value());
1008 }
1009 }
1010 _ => {}
1011 }
1012 }
1013 } else if meta_list.tokens.to_string() == "skip" {
1014 result.skip = true;
1016 } else if meta_list.tokens.to_string() == "image" {
1017 result.image = true;
1019 }
1020 }
1021 }
1022 }
1023
1024 result
1025}
1026
1027#[proc_macro_derive(ToPrompt, attributes(prompt))]
1070pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
1071 let input = parse_macro_input!(input as DeriveInput);
1072
1073 let found_crate =
1074 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1075 let crate_path = match found_crate {
1076 FoundCrate::Itself => {
1077 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1079 quote!(::#ident)
1080 }
1081 FoundCrate::Name(name) => {
1082 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1083 quote!(::#ident)
1084 }
1085 };
1086
1087 match &input.data {
1089 Data::Enum(data_enum) => {
1090 let enum_name = &input.ident;
1092 let enum_docs = extract_doc_comments(&input.attrs);
1093
1094 let serde_tag = parse_serde_tag(&input.attrs);
1096 let is_internally_tagged = serde_tag.is_some();
1097 let is_untagged = parse_serde_untagged(&input.attrs);
1098
1099 let rename_rule = parse_serde_rename_all(&input.attrs);
1101
1102 let mut variant_lines = Vec::new();
1115 let mut first_variant_name = None;
1116
1117 let mut example_unit: Option<String> = None;
1119 let mut example_struct: Option<String> = None;
1120 let mut example_tuple: Option<String> = None;
1121
1122 let mut nested_types: Vec<&syn::Type> = Vec::new();
1124
1125 for variant in &data_enum.variants {
1126 let variant_name = &variant.ident;
1127 let variant_name_str = variant_name.to_string();
1128
1129 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1131
1132 if prompt_attrs.skip {
1134 continue;
1135 }
1136
1137 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1143 prompt_rename.clone()
1144 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1145 serde_rename
1146 } else if let Some(rule) = rename_rule {
1147 rule.apply(&variant_name_str)
1148 } else {
1149 variant_name_str.clone()
1150 };
1151
1152 let variant_line = match &variant.fields {
1154 syn::Fields::Unit => {
1155 if example_unit.is_none() {
1157 example_unit = Some(format!("\"{}\"", variant_value));
1158 }
1159
1160 if let Some(desc) = &prompt_attrs.description {
1162 format!(" | \"{}\" // {}", variant_value, desc)
1163 } else {
1164 let docs = extract_doc_comments(&variant.attrs);
1165 if !docs.is_empty() {
1166 format!(" | \"{}\" // {}", variant_value, docs)
1167 } else {
1168 format!(" | \"{}\"", variant_value)
1169 }
1170 }
1171 }
1172 syn::Fields::Named(fields) => {
1173 let mut field_parts = Vec::new();
1174 let mut example_field_parts = Vec::new();
1175
1176 if is_internally_tagged && let Some(tag_name) = &serde_tag {
1178 field_parts.push(format!("{}: \"{}\"", tag_name, variant_value));
1179 example_field_parts
1180 .push(format!("{}: \"{}\"", tag_name, variant_value));
1181 }
1182
1183 for field in &fields.named {
1184 let field_name = field.ident.as_ref().unwrap().to_string();
1185 let field_type = format_type_for_schema(&field.ty);
1186 field_parts.push(format!("{}: {}", field_name, field_type.clone()));
1187
1188 let expandable_type = extract_expandable_type(&field.ty);
1191 if !is_primitive_type(expandable_type) {
1192 nested_types.push(expandable_type);
1193 }
1194
1195 let example_value = generate_example_value_for_type(&field_type);
1197 example_field_parts.push(format!("{}: {}", field_name, example_value));
1198 }
1199
1200 let field_str = field_parts.join(", ");
1201 let example_field_str = example_field_parts.join(", ");
1202
1203 if example_struct.is_none() {
1205 if is_untagged || is_internally_tagged {
1206 example_struct = Some(format!("{{ {} }}", example_field_str));
1207 } else {
1208 example_struct = Some(format!(
1209 "{{ \"{}\": {{ {} }} }}",
1210 variant_value, example_field_str
1211 ));
1212 }
1213 }
1214
1215 let comment = if let Some(desc) = &prompt_attrs.description {
1216 format!(" // {}", desc)
1217 } else {
1218 let docs = extract_doc_comments(&variant.attrs);
1219 if !docs.is_empty() {
1220 format!(" // {}", docs)
1221 } else if is_untagged {
1222 format!(" // {}", variant_value)
1224 } else {
1225 String::new()
1226 }
1227 };
1228
1229 if is_untagged {
1230 format!(" | {{ {} }}{}", field_str, comment)
1232 } else if is_internally_tagged {
1233 format!(" | {{ {} }}{}", field_str, comment)
1235 } else {
1236 format!(
1238 " | {{ \"{}\": {{ {} }} }}{}",
1239 variant_value, field_str, comment
1240 )
1241 }
1242 }
1243 syn::Fields::Unnamed(fields) => {
1244 let field_types: Vec<String> = fields
1245 .unnamed
1246 .iter()
1247 .map(|f| {
1248 let expandable_type = extract_expandable_type(&f.ty);
1251 if !is_primitive_type(expandable_type) {
1252 nested_types.push(expandable_type);
1253 }
1254 format_type_for_schema(&f.ty)
1255 })
1256 .collect();
1257
1258 let tuple_str = field_types.join(", ");
1259
1260 let example_values: Vec<String> = field_types
1262 .iter()
1263 .map(|type_str| generate_example_value_for_type(type_str))
1264 .collect();
1265 let example_tuple_str = example_values.join(", ");
1266
1267 if example_tuple.is_none() {
1269 if is_untagged || is_internally_tagged {
1270 example_tuple = Some(format!("[{}]", example_tuple_str));
1271 } else {
1272 example_tuple = Some(format!(
1273 "{{ \"{}\": [{}] }}",
1274 variant_value, example_tuple_str
1275 ));
1276 }
1277 }
1278
1279 let comment = if let Some(desc) = &prompt_attrs.description {
1280 format!(" // {}", desc)
1281 } else {
1282 let docs = extract_doc_comments(&variant.attrs);
1283 if !docs.is_empty() {
1284 format!(" // {}", docs)
1285 } else if is_untagged {
1286 format!(" // {}", variant_value)
1288 } else {
1289 String::new()
1290 }
1291 };
1292
1293 if is_untagged || is_internally_tagged {
1294 format!(" | [{}]{}", tuple_str, comment)
1297 } else {
1298 format!(
1300 " | {{ \"{}\": [{}] }}{}",
1301 variant_value, tuple_str, comment
1302 )
1303 }
1304 }
1305 };
1306
1307 variant_lines.push(variant_line);
1308
1309 if first_variant_name.is_none() {
1310 first_variant_name = Some(variant_value);
1311 }
1312 }
1313
1314 let mut lines = Vec::new();
1316
1317 if !enum_docs.is_empty() {
1319 lines.push("/**".to_string());
1320 lines.push(format!(" * {}", enum_docs));
1321 lines.push(" */".to_string());
1322 }
1323
1324 lines.push(format!("type {} =", enum_name));
1326
1327 for line in &variant_lines {
1329 lines.push(line.clone());
1330 }
1331
1332 if let Some(last) = lines.last_mut()
1334 && !last.ends_with(';')
1335 {
1336 last.push(';');
1337 }
1338
1339 let mut examples = Vec::new();
1341 if let Some(ex) = example_unit {
1342 examples.push(ex);
1343 }
1344 if let Some(ex) = example_struct {
1345 examples.push(ex);
1346 }
1347 if let Some(ex) = example_tuple {
1348 examples.push(ex);
1349 }
1350
1351 if !examples.is_empty() {
1352 lines.push("".to_string()); if examples.len() == 1 {
1354 lines.push(format!("Example value: {}", examples[0]));
1355 } else {
1356 lines.push("Example values:".to_string());
1357 for ex in examples {
1358 lines.push(format!(" {}", ex));
1359 }
1360 }
1361 }
1362
1363 let nested_type_tokens: Vec<_> = nested_types
1365 .iter()
1366 .map(|field_ty| {
1367 quote! {
1368 {
1369 let type_schema = <#field_ty as #crate_path::prompt::ToPrompt>::prompt_schema();
1370 if !type_schema.is_empty() {
1371 format!("\n\n{}", type_schema)
1372 } else {
1373 String::new()
1374 }
1375 }
1376 }
1377 })
1378 .collect();
1379
1380 let prompt_string = if nested_type_tokens.is_empty() {
1381 let lines_str = lines.join("\n");
1382 quote! { #lines_str.to_string() }
1383 } else {
1384 let lines_str = lines.join("\n");
1385 quote! {
1386 {
1387 let mut result = String::from(#lines_str);
1388
1389 let nested_schemas: Vec<String> = vec![#(#nested_type_tokens),*];
1391 let mut seen_schemas = std::collections::HashSet::<String>::new();
1392
1393 for schema in nested_schemas {
1394 if !schema.is_empty() && seen_schemas.insert(schema.clone()) {
1395 result.push_str(&schema);
1396 }
1397 }
1398
1399 result
1400 }
1401 }
1402 };
1403 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1404
1405 let mut match_arms = Vec::new();
1407 for variant in &data_enum.variants {
1408 let variant_name = &variant.ident;
1409 let variant_name_str = variant_name.to_string();
1410
1411 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1413
1414 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1420 prompt_rename.clone()
1421 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1422 serde_rename
1423 } else if let Some(rule) = rename_rule {
1424 rule.apply(&variant_name_str)
1425 } else {
1426 variant_name_str.clone()
1427 };
1428
1429 match &variant.fields {
1431 syn::Fields::Unit => {
1432 if prompt_attrs.skip {
1434 match_arms.push(quote! {
1435 Self::#variant_name => stringify!(#variant_name).to_string()
1436 });
1437 } else if let Some(desc) = &prompt_attrs.description {
1438 match_arms.push(quote! {
1439 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
1440 });
1441 } else {
1442 let variant_docs = extract_doc_comments(&variant.attrs);
1443 if !variant_docs.is_empty() {
1444 match_arms.push(quote! {
1445 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
1446 });
1447 } else {
1448 match_arms.push(quote! {
1449 Self::#variant_name => #variant_value.to_string()
1450 });
1451 }
1452 }
1453 }
1454 syn::Fields::Named(fields) => {
1455 let field_bindings: Vec<_> = fields
1457 .named
1458 .iter()
1459 .map(|f| f.ident.as_ref().unwrap())
1460 .collect();
1461
1462 let field_displays: Vec<_> = fields
1463 .named
1464 .iter()
1465 .map(|f| {
1466 let field_name = f.ident.as_ref().unwrap();
1467 let field_name_str = field_name.to_string();
1468 quote! {
1469 format!("{}: {:?}", #field_name_str, #field_name)
1470 }
1471 })
1472 .collect();
1473
1474 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1475 desc.clone()
1476 } else {
1477 let docs = extract_doc_comments(&variant.attrs);
1478 if !docs.is_empty() {
1479 docs
1480 } else {
1481 String::new()
1482 }
1483 };
1484
1485 if doc_or_desc.is_empty() {
1486 match_arms.push(quote! {
1487 Self::#variant_name { #(#field_bindings),* } => {
1488 let fields = vec![#(#field_displays),*];
1489 format!("{} {{ {} }}", #variant_value, fields.join(", "))
1490 }
1491 });
1492 } else {
1493 match_arms.push(quote! {
1494 Self::#variant_name { #(#field_bindings),* } => {
1495 let fields = vec![#(#field_displays),*];
1496 format!("{}: {} {{ {} }}", #variant_value, #doc_or_desc, fields.join(", "))
1497 }
1498 });
1499 }
1500 }
1501 syn::Fields::Unnamed(fields) => {
1502 let field_count = fields.unnamed.len();
1504 let field_bindings: Vec<_> = (0..field_count)
1505 .map(|i| {
1506 syn::Ident::new(
1507 &format!("field{}", i),
1508 proc_macro2::Span::call_site(),
1509 )
1510 })
1511 .collect();
1512
1513 let field_displays: Vec<_> = field_bindings
1514 .iter()
1515 .map(|field_name| {
1516 quote! {
1517 format!("{:?}", #field_name)
1518 }
1519 })
1520 .collect();
1521
1522 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1523 desc.clone()
1524 } else {
1525 let docs = extract_doc_comments(&variant.attrs);
1526 if !docs.is_empty() {
1527 docs
1528 } else {
1529 String::new()
1530 }
1531 };
1532
1533 if doc_or_desc.is_empty() {
1534 match_arms.push(quote! {
1535 Self::#variant_name(#(#field_bindings),*) => {
1536 let fields = vec![#(#field_displays),*];
1537 format!("{}({})", #variant_value, fields.join(", "))
1538 }
1539 });
1540 } else {
1541 match_arms.push(quote! {
1542 Self::#variant_name(#(#field_bindings),*) => {
1543 let fields = vec![#(#field_displays),*];
1544 format!("{}: {}({})", #variant_value, #doc_or_desc, fields.join(", "))
1545 }
1546 });
1547 }
1548 }
1549 }
1550 }
1551
1552 let to_prompt_impl = if match_arms.is_empty() {
1553 quote! {
1555 fn to_prompt(&self) -> String {
1556 match *self {}
1557 }
1558 }
1559 } else {
1560 quote! {
1561 fn to_prompt(&self) -> String {
1562 match self {
1563 #(#match_arms),*
1564 }
1565 }
1566 }
1567 };
1568
1569 let expanded = quote! {
1570 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
1571 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1572 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
1573 }
1574
1575 #to_prompt_impl
1576
1577 fn prompt_schema() -> String {
1578 #prompt_string
1579 }
1580 }
1581 };
1582
1583 TokenStream::from(expanded)
1584 }
1585 Data::Struct(data_struct) => {
1586 let mut template_attr = None;
1588 let mut template_file_attr = None;
1589 let mut mode_attr = None;
1590 let mut validate_attr = false;
1591 let mut type_marker_attr = false;
1592
1593 for attr in &input.attrs {
1594 if attr.path().is_ident("prompt") {
1595 if let Ok(metas) =
1597 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1598 {
1599 for meta in metas {
1600 match meta {
1601 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1602 if let syn::Expr::Lit(expr_lit) = nv.value
1603 && let syn::Lit::Str(lit_str) = expr_lit.lit
1604 {
1605 template_attr = Some(lit_str.value());
1606 }
1607 }
1608 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
1609 if let syn::Expr::Lit(expr_lit) = nv.value
1610 && let syn::Lit::Str(lit_str) = expr_lit.lit
1611 {
1612 template_file_attr = Some(lit_str.value());
1613 }
1614 }
1615 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1616 if let syn::Expr::Lit(expr_lit) = nv.value
1617 && let syn::Lit::Str(lit_str) = expr_lit.lit
1618 {
1619 mode_attr = Some(lit_str.value());
1620 }
1621 }
1622 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
1623 if let syn::Expr::Lit(expr_lit) = nv.value
1624 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1625 {
1626 validate_attr = lit_bool.value();
1627 }
1628 }
1629 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
1630 if let syn::Expr::Lit(expr_lit) = nv.value
1631 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1632 {
1633 type_marker_attr = lit_bool.value();
1634 }
1635 }
1636 Meta::Path(path) if path.is_ident("type_marker") => {
1637 type_marker_attr = true;
1639 }
1640 _ => {}
1641 }
1642 }
1643 }
1644 }
1645 }
1646
1647 if template_attr.is_some() && template_file_attr.is_some() {
1649 return syn::Error::new(
1650 input.ident.span(),
1651 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
1652 ).to_compile_error().into();
1653 }
1654
1655 let template_str = if let Some(file_path) = template_file_attr {
1657 let mut full_path = None;
1661
1662 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1664 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
1666
1667 if !is_trybuild {
1668 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1670 if candidate.exists() {
1671 full_path = Some(candidate);
1672 }
1673 } else {
1674 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1680 let workspace_root = &manifest_dir[..target_pos];
1681 let original_macros_dir = std::path::Path::new(workspace_root)
1683 .join("crates")
1684 .join("llm-toolkit-macros");
1685
1686 let candidate = original_macros_dir.join(&file_path);
1687 if candidate.exists() {
1688 full_path = Some(candidate);
1689 }
1690 }
1691 }
1692 }
1693
1694 if full_path.is_none() {
1696 let candidate = std::path::Path::new(&file_path).to_path_buf();
1697 if candidate.exists() {
1698 full_path = Some(candidate);
1699 }
1700 }
1701
1702 if full_path.is_none()
1705 && let Ok(current_dir) = std::env::current_dir()
1706 {
1707 let mut search_dir = current_dir.as_path();
1708 for _ in 0..10 {
1710 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1712 if macros_dir.exists() {
1713 let candidate = macros_dir.join(&file_path);
1714 if candidate.exists() {
1715 full_path = Some(candidate);
1716 break;
1717 }
1718 }
1719 let candidate = search_dir.join(&file_path);
1721 if candidate.exists() {
1722 full_path = Some(candidate);
1723 break;
1724 }
1725 if let Some(parent) = search_dir.parent() {
1726 search_dir = parent;
1727 } else {
1728 break;
1729 }
1730 }
1731 }
1732
1733 if full_path.is_none() {
1735 let mut error_msg = format!(
1737 "Template file '{}' not found at compile time.\n\nSearched in:",
1738 file_path
1739 );
1740
1741 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1742 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1743 error_msg.push_str(&format!("\n - {}", candidate.display()));
1744 }
1745
1746 if let Ok(current_dir) = std::env::current_dir() {
1747 let candidate = current_dir.join(&file_path);
1748 error_msg.push_str(&format!("\n - {}", candidate.display()));
1749 }
1750
1751 error_msg.push_str("\n\nPlease ensure:");
1752 error_msg.push_str("\n 1. The template file exists");
1753 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1754 error_msg.push_str("\n 3. There are no typos in the path");
1755
1756 return syn::Error::new(input.ident.span(), error_msg)
1757 .to_compile_error()
1758 .into();
1759 }
1760
1761 let final_path = full_path.unwrap();
1762
1763 match std::fs::read_to_string(&final_path) {
1765 Ok(content) => Some(content),
1766 Err(e) => {
1767 return syn::Error::new(
1768 input.ident.span(),
1769 format!(
1770 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1771 file_path,
1772 e,
1773 final_path.display()
1774 ),
1775 )
1776 .to_compile_error()
1777 .into();
1778 }
1779 }
1780 } else {
1781 template_attr
1782 };
1783
1784 if validate_attr && let Some(template) = &template_str {
1786 let mut env = minijinja::Environment::new();
1788 if let Err(e) = env.add_template("validation", template) {
1789 let warning_msg =
1791 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1792 let warning_ident = syn::Ident::new(
1793 "TEMPLATE_VALIDATION_WARNING",
1794 proc_macro2::Span::call_site(),
1795 );
1796 let _warning_tokens = quote! {
1797 #[deprecated(note = #warning_msg)]
1798 const #warning_ident: () = ();
1799 let _ = #warning_ident;
1800 };
1801 eprintln!("cargo:warning={}", warning_msg);
1803 }
1804
1805 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1807 &fields.named
1808 } else {
1809 panic!("Template validation is only supported for structs with named fields.");
1810 };
1811
1812 let field_names: std::collections::HashSet<String> = fields
1813 .iter()
1814 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1815 .collect();
1816
1817 let placeholders = parse_template_placeholders_with_mode(template);
1819
1820 for (placeholder_name, _mode) in &placeholders {
1821 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1822 let warning_msg = format!(
1823 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1824 placeholder_name
1825 );
1826 eprintln!("cargo:warning={}", warning_msg);
1827 }
1828 }
1829 }
1830
1831 let name = input.ident;
1832 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1833
1834 let struct_docs = extract_doc_comments(&input.attrs);
1836
1837 let is_mode_based =
1839 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1840
1841 let expanded = if is_mode_based || mode_attr.is_some() {
1842 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1844 &fields.named
1845 } else {
1846 panic!(
1847 "Mode-based prompt generation is only supported for structs with named fields."
1848 );
1849 };
1850
1851 let struct_name_str = name.to_string();
1852
1853 let has_default = input.attrs.iter().any(|attr| {
1855 if attr.path().is_ident("derive")
1856 && let Ok(meta_list) = attr.meta.require_list()
1857 {
1858 let tokens_str = meta_list.tokens.to_string();
1859 tokens_str.contains("Default")
1860 } else {
1861 false
1862 }
1863 });
1864
1865 let schema_parts = generate_schema_only_parts(
1876 &struct_name_str,
1877 &struct_docs,
1878 fields,
1879 &crate_path,
1880 type_marker_attr,
1881 );
1882
1883 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1885
1886 quote! {
1887 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1888 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1889 match mode {
1890 "schema_only" => #schema_parts,
1891 "example_only" => #example_parts,
1892 "full" | _ => {
1893 let mut parts = Vec::new();
1895
1896 let schema_parts = #schema_parts;
1898 parts.extend(schema_parts);
1899
1900 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1902 parts.push(#crate_path::prompt::PromptPart::Text(
1903 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1904 ));
1905
1906 let example_parts = #example_parts;
1908 parts.extend(example_parts);
1909
1910 parts
1911 }
1912 }
1913 }
1914
1915 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1916 self.to_prompt_parts_with_mode("full")
1917 }
1918
1919 fn to_prompt(&self) -> String {
1920 self.to_prompt_parts()
1921 .into_iter()
1922 .filter_map(|part| match part {
1923 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1924 _ => None,
1925 })
1926 .collect::<Vec<_>>()
1927 .join("\n")
1928 }
1929
1930 fn prompt_schema() -> String {
1931 use std::sync::OnceLock;
1932 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1933
1934 SCHEMA_CACHE.get_or_init(|| {
1935 let schema_parts = #schema_parts;
1936 schema_parts
1937 .into_iter()
1938 .filter_map(|part| match part {
1939 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1940 _ => None,
1941 })
1942 .collect::<Vec<_>>()
1943 .join("\n")
1944 }).clone()
1945 }
1946 }
1947 }
1948 } else if let Some(template) = template_str {
1949 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1952 &fields.named
1953 } else {
1954 panic!(
1955 "Template prompt generation is only supported for structs with named fields."
1956 );
1957 };
1958
1959 let placeholders = parse_template_placeholders_with_mode(&template);
1961 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1963 mode.is_some()
1964 && fields
1965 .iter()
1966 .any(|f| f.ident.as_ref().unwrap() == field_name)
1967 });
1968
1969 let mut image_field_parts = Vec::new();
1970 for f in fields.iter() {
1971 let field_name = f.ident.as_ref().unwrap();
1972 let attrs = parse_field_prompt_attrs(&f.attrs);
1973
1974 if attrs.image {
1975 image_field_parts.push(quote! {
1977 parts.extend(self.#field_name.to_prompt_parts());
1978 });
1979 }
1980 }
1981
1982 if has_mode_syntax {
1984 let mut context_fields = Vec::new();
1986 let mut modified_template = template.clone();
1987
1988 for (field_name, mode_opt) in &placeholders {
1990 if let Some(mode) = mode_opt {
1991 let unique_key = format!("{}__{}", field_name, mode);
1993
1994 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1996 let replacement = format!("{{{{ {} }}}}", unique_key);
1997 modified_template = modified_template.replace(&pattern, &replacement);
1998
1999 let field_ident =
2001 syn::Ident::new(field_name, proc_macro2::Span::call_site());
2002
2003 context_fields.push(quote! {
2005 context.insert(
2006 #unique_key.to_string(),
2007 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
2008 );
2009 });
2010 }
2011 }
2012
2013 for field in fields.iter() {
2015 let field_name = field.ident.as_ref().unwrap();
2016 let field_name_str = field_name.to_string();
2017
2018 let has_mode_entry = placeholders
2020 .iter()
2021 .any(|(name, mode)| name == &field_name_str && mode.is_some());
2022
2023 if !has_mode_entry {
2024 let is_primitive = match &field.ty {
2027 syn::Type::Path(type_path) => {
2028 if let Some(segment) = type_path.path.segments.last() {
2029 let type_name = segment.ident.to_string();
2030 matches!(
2031 type_name.as_str(),
2032 "String"
2033 | "str"
2034 | "i8"
2035 | "i16"
2036 | "i32"
2037 | "i64"
2038 | "i128"
2039 | "isize"
2040 | "u8"
2041 | "u16"
2042 | "u32"
2043 | "u64"
2044 | "u128"
2045 | "usize"
2046 | "f32"
2047 | "f64"
2048 | "bool"
2049 | "char"
2050 )
2051 } else {
2052 false
2053 }
2054 }
2055 _ => false,
2056 };
2057
2058 if is_primitive {
2059 context_fields.push(quote! {
2060 context.insert(
2061 #field_name_str.to_string(),
2062 minijinja::Value::from_serialize(&self.#field_name)
2063 );
2064 });
2065 } else {
2066 context_fields.push(quote! {
2068 context.insert(
2069 #field_name_str.to_string(),
2070 minijinja::Value::from(self.#field_name.to_prompt())
2071 );
2072 });
2073 }
2074 }
2075 }
2076
2077 quote! {
2078 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2079 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2080 let mut parts = Vec::new();
2081
2082 #(#image_field_parts)*
2084
2085 let text = {
2087 let mut env = minijinja::Environment::new();
2088 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
2089 panic!("Failed to parse template: {}", e)
2090 });
2091
2092 let tmpl = env.get_template("prompt").unwrap();
2093
2094 let mut context = std::collections::HashMap::new();
2095 #(#context_fields)*
2096
2097 tmpl.render(context).unwrap_or_else(|e| {
2098 format!("Failed to render prompt: {}", e)
2099 })
2100 };
2101
2102 if !text.is_empty() {
2103 parts.push(#crate_path::prompt::PromptPart::Text(text));
2104 }
2105
2106 parts
2107 }
2108
2109 fn to_prompt(&self) -> String {
2110 let mut env = minijinja::Environment::new();
2112 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
2113 panic!("Failed to parse template: {}", e)
2114 });
2115
2116 let tmpl = env.get_template("prompt").unwrap();
2117
2118 let mut context = std::collections::HashMap::new();
2119 #(#context_fields)*
2120
2121 tmpl.render(context).unwrap_or_else(|e| {
2122 format!("Failed to render prompt: {}", e)
2123 })
2124 }
2125
2126 fn prompt_schema() -> String {
2127 String::new() }
2129 }
2130 }
2131 } else {
2132 let mut simple_context_fields = Vec::new();
2137 for field in fields.iter() {
2138 let field_name = field.ident.as_ref().unwrap();
2139 let field_name_str = field_name.to_string();
2140
2141 let is_primitive = match &field.ty {
2143 syn::Type::Path(type_path) => {
2144 if let Some(segment) = type_path.path.segments.last() {
2145 let type_name = segment.ident.to_string();
2146 matches!(
2147 type_name.as_str(),
2148 "String"
2149 | "str"
2150 | "i8"
2151 | "i16"
2152 | "i32"
2153 | "i64"
2154 | "i128"
2155 | "isize"
2156 | "u8"
2157 | "u16"
2158 | "u32"
2159 | "u64"
2160 | "u128"
2161 | "usize"
2162 | "f32"
2163 | "f64"
2164 | "bool"
2165 | "char"
2166 )
2167 } else {
2168 false
2169 }
2170 }
2171 _ => false,
2172 };
2173
2174 if is_primitive {
2175 simple_context_fields.push(quote! {
2176 context.insert(
2177 #field_name_str.to_string(),
2178 minijinja::Value::from_serialize(&self.#field_name)
2179 );
2180 });
2181 } else {
2182 simple_context_fields.push(quote! {
2184 context.insert(
2185 #field_name_str.to_string(),
2186 minijinja::Value::from(self.#field_name.to_prompt())
2187 );
2188 });
2189 }
2190 }
2191
2192 quote! {
2193 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2194 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2195 let mut parts = Vec::new();
2196
2197 #(#image_field_parts)*
2199
2200 let text = {
2202 let mut env = minijinja::Environment::new();
2203 env.add_template("prompt", #template).unwrap_or_else(|e| {
2204 panic!("Failed to parse template: {}", e)
2205 });
2206
2207 let tmpl = env.get_template("prompt").unwrap();
2208
2209 let mut context = std::collections::HashMap::new();
2210 #(#simple_context_fields)*
2211
2212 tmpl.render(context).unwrap_or_else(|e| {
2213 format!("Failed to render prompt: {}", e)
2214 })
2215 };
2216
2217 if !text.is_empty() {
2218 parts.push(#crate_path::prompt::PromptPart::Text(text));
2219 }
2220
2221 parts
2222 }
2223
2224 fn to_prompt(&self) -> String {
2225 let mut env = minijinja::Environment::new();
2227 env.add_template("prompt", #template).unwrap_or_else(|e| {
2228 panic!("Failed to parse template: {}", e)
2229 });
2230
2231 let tmpl = env.get_template("prompt").unwrap();
2232
2233 let mut context = std::collections::HashMap::new();
2234 #(#simple_context_fields)*
2235
2236 tmpl.render(context).unwrap_or_else(|e| {
2237 format!("Failed to render prompt: {}", e)
2238 })
2239 }
2240
2241 fn prompt_schema() -> String {
2242 String::new() }
2244 }
2245 }
2246 }
2247 } else {
2248 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
2251 &fields.named
2252 } else {
2253 panic!(
2254 "Default prompt generation is only supported for structs with named fields."
2255 );
2256 };
2257
2258 let mut text_field_parts = Vec::new();
2260 let mut image_field_parts = Vec::new();
2261
2262 for f in fields.iter() {
2263 let field_name = f.ident.as_ref().unwrap();
2264 let attrs = parse_field_prompt_attrs(&f.attrs);
2265
2266 if attrs.skip {
2268 continue;
2269 }
2270
2271 if attrs.image {
2272 image_field_parts.push(quote! {
2274 parts.extend(self.#field_name.to_prompt_parts());
2275 });
2276 } else {
2277 let key = if let Some(rename) = attrs.rename {
2283 rename
2284 } else {
2285 let doc_comment = extract_doc_comments(&f.attrs);
2286 if !doc_comment.is_empty() {
2287 doc_comment
2288 } else {
2289 field_name.to_string()
2290 }
2291 };
2292
2293 let value_expr = if let Some(format_with) = attrs.format_with {
2295 let func_path: syn::Path =
2297 syn::parse_str(&format_with).unwrap_or_else(|_| {
2298 panic!("Invalid function path: {}", format_with)
2299 });
2300 quote! { #func_path(&self.#field_name) }
2301 } else {
2302 quote! { self.#field_name.to_prompt() }
2303 };
2304
2305 text_field_parts.push(quote! {
2306 text_parts.push(format!("{}: {}", #key, #value_expr));
2307 });
2308 }
2309 }
2310
2311 let struct_name_str = name.to_string();
2313 let schema_parts = generate_schema_only_parts(
2314 &struct_name_str,
2315 &struct_docs,
2316 fields,
2317 &crate_path,
2318 false, );
2320
2321 quote! {
2323 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2324 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2325 let mut parts = Vec::new();
2326
2327 #(#image_field_parts)*
2329
2330 let mut text_parts = Vec::new();
2332 #(#text_field_parts)*
2333
2334 if !text_parts.is_empty() {
2335 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2336 }
2337
2338 parts
2339 }
2340
2341 fn to_prompt(&self) -> String {
2342 let mut text_parts = Vec::new();
2343 #(#text_field_parts)*
2344 text_parts.join("\n")
2345 }
2346
2347 fn prompt_schema() -> String {
2348 use std::sync::OnceLock;
2349 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
2350
2351 SCHEMA_CACHE.get_or_init(|| {
2352 let schema_parts = #schema_parts;
2353 schema_parts
2354 .into_iter()
2355 .filter_map(|part| match part {
2356 #crate_path::prompt::PromptPart::Text(text) => Some(text),
2357 _ => None,
2358 })
2359 .collect::<Vec<_>>()
2360 .join("\n")
2361 }).clone()
2362 }
2363 }
2364 }
2365 };
2366
2367 TokenStream::from(expanded)
2368 }
2369 Data::Union(_) => {
2370 panic!("`#[derive(ToPrompt)]` is not supported for unions");
2371 }
2372 }
2373}
2374
2375#[derive(Debug, Clone)]
2377struct TargetInfo {
2378 name: String,
2379 template: Option<String>,
2380 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
2381}
2382
/// Settings a single field carries for one prompt target.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Leave the field out of the target's output.
    skip: bool,
    /// Key to print instead of the field name.
    rename: Option<String>,
    /// Path of a function that is called with a reference to the field to
    /// produce its printed value.
    format_with: Option<String>,
    /// Emit the field's own prompt parts instead of a `key: value` text line.
    image: bool,
    /// Set when the settings come from an attribute that names its target.
    include_only: bool,
}
2392
2393fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
2395 let mut configs = Vec::new();
2396
2397 for attr in attrs {
2398 if attr.path().is_ident("prompt_for")
2399 && let Ok(meta_list) = attr.meta.require_list()
2400 {
2401 if meta_list.tokens.to_string() == "skip" {
2403 let config = FieldTargetConfig {
2405 skip: true,
2406 ..Default::default()
2407 };
2408 configs.push(("*".to_string(), config));
2409 } else if let Ok(metas) =
2410 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2411 {
2412 let mut target_name = None;
2413 let mut config = FieldTargetConfig::default();
2414
2415 for meta in metas {
2416 match meta {
2417 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2418 if let syn::Expr::Lit(syn::ExprLit {
2419 lit: syn::Lit::Str(lit_str),
2420 ..
2421 }) = nv.value
2422 {
2423 target_name = Some(lit_str.value());
2424 }
2425 }
2426 Meta::Path(path) if path.is_ident("skip") => {
2427 config.skip = true;
2428 }
2429 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
2430 if let syn::Expr::Lit(syn::ExprLit {
2431 lit: syn::Lit::Str(lit_str),
2432 ..
2433 }) = nv.value
2434 {
2435 config.rename = Some(lit_str.value());
2436 }
2437 }
2438 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
2439 if let syn::Expr::Lit(syn::ExprLit {
2440 lit: syn::Lit::Str(lit_str),
2441 ..
2442 }) = nv.value
2443 {
2444 config.format_with = Some(lit_str.value());
2445 }
2446 }
2447 Meta::Path(path) if path.is_ident("image") => {
2448 config.image = true;
2449 }
2450 _ => {}
2451 }
2452 }
2453
2454 if let Some(name) = target_name {
2455 config.include_only = true;
2456 configs.push((name, config));
2457 }
2458 }
2459 }
2460 }
2461
2462 configs
2463}
2464
2465fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
2467 let mut targets = Vec::new();
2468
2469 for attr in attrs {
2470 if attr.path().is_ident("prompt_for")
2471 && let Ok(meta_list) = attr.meta.require_list()
2472 && let Ok(metas) =
2473 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2474 {
2475 let mut target_name = None;
2476 let mut template = None;
2477
2478 for meta in metas {
2479 match meta {
2480 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2481 if let syn::Expr::Lit(syn::ExprLit {
2482 lit: syn::Lit::Str(lit_str),
2483 ..
2484 }) = nv.value
2485 {
2486 target_name = Some(lit_str.value());
2487 }
2488 }
2489 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2490 if let syn::Expr::Lit(syn::ExprLit {
2491 lit: syn::Lit::Str(lit_str),
2492 ..
2493 }) = nv.value
2494 {
2495 template = Some(lit_str.value());
2496 }
2497 }
2498 _ => {}
2499 }
2500 }
2501
2502 if let Some(name) = target_name {
2503 targets.push(TargetInfo {
2504 name,
2505 template,
2506 field_configs: std::collections::HashMap::new(),
2507 });
2508 }
2509 }
2510 }
2511
2512 targets
2513}
2514
2515#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
2516pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
2517 let input = parse_macro_input!(input as DeriveInput);
2518
2519 let found_crate =
2520 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2521 let crate_path = match found_crate {
2522 FoundCrate::Itself => {
2523 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2525 quote!(::#ident)
2526 }
2527 FoundCrate::Name(name) => {
2528 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2529 quote!(::#ident)
2530 }
2531 };
2532
2533 let data_struct = match &input.data {
2535 Data::Struct(data) => data,
2536 _ => {
2537 return syn::Error::new(
2538 input.ident.span(),
2539 "`#[derive(ToPromptSet)]` is only supported for structs",
2540 )
2541 .to_compile_error()
2542 .into();
2543 }
2544 };
2545
2546 let fields = match &data_struct.fields {
2547 syn::Fields::Named(fields) => &fields.named,
2548 _ => {
2549 return syn::Error::new(
2550 input.ident.span(),
2551 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
2552 )
2553 .to_compile_error()
2554 .into();
2555 }
2556 };
2557
2558 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
2560
2561 for field in fields.iter() {
2563 let field_name = field.ident.as_ref().unwrap().to_string();
2564 let field_configs = parse_prompt_for_attrs(&field.attrs);
2565
2566 for (target_name, config) in field_configs {
2567 if target_name == "*" {
2568 for target in &mut targets {
2570 target
2571 .field_configs
2572 .entry(field_name.clone())
2573 .or_insert_with(FieldTargetConfig::default)
2574 .skip = config.skip;
2575 }
2576 } else {
2577 let target_exists = targets.iter().any(|t| t.name == target_name);
2579 if !target_exists {
2580 targets.push(TargetInfo {
2582 name: target_name.clone(),
2583 template: None,
2584 field_configs: std::collections::HashMap::new(),
2585 });
2586 }
2587
2588 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
2589
2590 target.field_configs.insert(field_name.clone(), config);
2591 }
2592 }
2593 }
2594
2595 let mut match_arms = Vec::new();
2597
2598 for target in &targets {
2599 let target_name = &target.name;
2600
2601 if let Some(template_str) = &target.template {
2602 let mut image_parts = Vec::new();
2604
2605 for field in fields.iter() {
2606 let field_name = field.ident.as_ref().unwrap();
2607 let field_name_str = field_name.to_string();
2608
2609 if let Some(config) = target.field_configs.get(&field_name_str)
2610 && config.image
2611 {
2612 image_parts.push(quote! {
2613 parts.extend(self.#field_name.to_prompt_parts());
2614 });
2615 }
2616 }
2617
2618 match_arms.push(quote! {
2619 #target_name => {
2620 let mut parts = Vec::new();
2621
2622 #(#image_parts)*
2623
2624 let text = #crate_path::prompt::render_prompt(#template_str, self)
2625 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
2626 target: #target_name.to_string(),
2627 source: e,
2628 })?;
2629
2630 if !text.is_empty() {
2631 parts.push(#crate_path::prompt::PromptPart::Text(text));
2632 }
2633
2634 Ok(parts)
2635 }
2636 });
2637 } else {
2638 let mut text_field_parts = Vec::new();
2640 let mut image_field_parts = Vec::new();
2641
2642 for field in fields.iter() {
2643 let field_name = field.ident.as_ref().unwrap();
2644 let field_name_str = field_name.to_string();
2645
2646 let config = target.field_configs.get(&field_name_str);
2648
2649 if let Some(cfg) = config
2651 && cfg.skip
2652 {
2653 continue;
2654 }
2655
2656 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
2660 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
2661 .iter()
2662 .any(|(name, _)| name != "*");
2663
2664 if has_any_target_specific_config && !is_explicitly_for_this_target {
2665 continue;
2666 }
2667
2668 if let Some(cfg) = config {
2669 if cfg.image {
2670 image_field_parts.push(quote! {
2671 parts.extend(self.#field_name.to_prompt_parts());
2672 });
2673 } else {
2674 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
2675
2676 let value_expr = if let Some(format_with) = &cfg.format_with {
2677 match syn::parse_str::<syn::Path>(format_with) {
2679 Ok(func_path) => quote! { #func_path(&self.#field_name) },
2680 Err(_) => {
2681 let error_msg = format!(
2683 "Invalid function path in format_with: '{}'",
2684 format_with
2685 );
2686 quote! {
2687 compile_error!(#error_msg);
2688 String::new()
2689 }
2690 }
2691 }
2692 } else {
2693 quote! { self.#field_name.to_prompt() }
2694 };
2695
2696 text_field_parts.push(quote! {
2697 text_parts.push(format!("{}: {}", #key, #value_expr));
2698 });
2699 }
2700 } else {
2701 text_field_parts.push(quote! {
2703 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
2704 });
2705 }
2706 }
2707
2708 match_arms.push(quote! {
2709 #target_name => {
2710 let mut parts = Vec::new();
2711
2712 #(#image_field_parts)*
2713
2714 let mut text_parts = Vec::new();
2715 #(#text_field_parts)*
2716
2717 if !text_parts.is_empty() {
2718 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2719 }
2720
2721 Ok(parts)
2722 }
2723 });
2724 }
2725 }
2726
2727 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
2729
2730 match_arms.push(quote! {
2732 _ => {
2733 let available = vec![#(#target_names.to_string()),*];
2734 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
2735 target: target.to_string(),
2736 available,
2737 })
2738 }
2739 });
2740
2741 let struct_name = &input.ident;
2742 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2743
2744 let expanded = quote! {
2745 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
2746 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
2747 match target {
2748 #(#match_arms)*
2749 }
2750 }
2751 }
2752 };
2753
2754 TokenStream::from(expanded)
2755}
2756
2757struct TypeList {
2759 types: Punctuated<syn::Type, Token![,]>,
2760}
2761
2762impl Parse for TypeList {
2763 fn parse(input: ParseStream) -> syn::Result<Self> {
2764 Ok(TypeList {
2765 types: Punctuated::parse_terminated(input)?,
2766 })
2767 }
2768}
2769
2770#[proc_macro]
2794pub fn examples_section(input: TokenStream) -> TokenStream {
2795 let input = parse_macro_input!(input as TypeList);
2796
2797 let found_crate =
2798 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2799 let _crate_path = match found_crate {
2800 FoundCrate::Itself => quote!(crate),
2801 FoundCrate::Name(name) => {
2802 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2803 quote!(::#ident)
2804 }
2805 };
2806
2807 let mut type_sections = Vec::new();
2809
2810 for ty in input.types.iter() {
2811 let type_name_str = quote!(#ty).to_string();
2813
2814 type_sections.push(quote! {
2816 {
2817 let type_name = #type_name_str;
2818 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
2819 format!("---\n#### `{}`\n{}", type_name, json_example)
2820 }
2821 });
2822 }
2823
2824 let expanded = quote! {
2826 {
2827 let mut sections = Vec::new();
2828 sections.push("---".to_string());
2829 sections.push("### Examples".to_string());
2830 sections.push("".to_string());
2831 sections.push("Here are examples of the data structures you should use.".to_string());
2832 sections.push("".to_string());
2833
2834 #(sections.push(#type_sections);)*
2835
2836 sections.push("---".to_string());
2837
2838 sections.join("\n")
2839 }
2840 };
2841
2842 TokenStream::from(expanded)
2843}
2844
2845fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
2847 for attr in attrs {
2848 if attr.path().is_ident("prompt_for")
2849 && let Ok(meta_list) = attr.meta.require_list()
2850 && let Ok(metas) =
2851 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2852 {
2853 let mut target_type = None;
2854 let mut template = None;
2855
2856 for meta in metas {
2857 match meta {
2858 Meta::NameValue(nv) if nv.path.is_ident("target") => {
2859 if let syn::Expr::Lit(syn::ExprLit {
2860 lit: syn::Lit::Str(lit_str),
2861 ..
2862 }) = nv.value
2863 {
2864 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2866 }
2867 }
2868 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2869 if let syn::Expr::Lit(syn::ExprLit {
2870 lit: syn::Lit::Str(lit_str),
2871 ..
2872 }) = nv.value
2873 {
2874 template = Some(lit_str.value());
2875 }
2876 }
2877 _ => {}
2878 }
2879 }
2880
2881 if let (Some(target), Some(tmpl)) = (target_type, template) {
2882 return (target, tmpl);
2883 }
2884 }
2885 }
2886
2887 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2888}
2889
2890#[proc_macro_attribute]
2924pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2925 let input = parse_macro_input!(item as DeriveInput);
2926
2927 let found_crate =
2928 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2929 let crate_path = match found_crate {
2930 FoundCrate::Itself => {
2931 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2933 quote!(::#ident)
2934 }
2935 FoundCrate::Name(name) => {
2936 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2937 quote!(::#ident)
2938 }
2939 };
2940
2941 let enum_data = match &input.data {
2943 Data::Enum(data) => data,
2944 _ => {
2945 return syn::Error::new(
2946 input.ident.span(),
2947 "`#[define_intent]` can only be applied to enums",
2948 )
2949 .to_compile_error()
2950 .into();
2951 }
2952 };
2953
2954 let mut prompt_template = None;
2956 let mut extractor_tag = None;
2957 let mut mode = None;
2958
2959 for attr in &input.attrs {
2960 if attr.path().is_ident("intent")
2961 && let Ok(metas) =
2962 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2963 {
2964 for meta in metas {
2965 match meta {
2966 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2967 if let syn::Expr::Lit(syn::ExprLit {
2968 lit: syn::Lit::Str(lit_str),
2969 ..
2970 }) = nv.value
2971 {
2972 prompt_template = Some(lit_str.value());
2973 }
2974 }
2975 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2976 if let syn::Expr::Lit(syn::ExprLit {
2977 lit: syn::Lit::Str(lit_str),
2978 ..
2979 }) = nv.value
2980 {
2981 extractor_tag = Some(lit_str.value());
2982 }
2983 }
2984 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2985 if let syn::Expr::Lit(syn::ExprLit {
2986 lit: syn::Lit::Str(lit_str),
2987 ..
2988 }) = nv.value
2989 {
2990 mode = Some(lit_str.value());
2991 }
2992 }
2993 _ => {}
2994 }
2995 }
2996 }
2997 }
2998
2999 let mode = mode.unwrap_or_else(|| "single".to_string());
3001
3002 if mode != "single" && mode != "multi_tag" {
3004 return syn::Error::new(
3005 input.ident.span(),
3006 "`mode` must be either \"single\" or \"multi_tag\"",
3007 )
3008 .to_compile_error()
3009 .into();
3010 }
3011
3012 let prompt_template = match prompt_template {
3014 Some(p) => p,
3015 None => {
3016 return syn::Error::new(
3017 input.ident.span(),
3018 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
3019 )
3020 .to_compile_error()
3021 .into();
3022 }
3023 };
3024
3025 if mode == "multi_tag" {
3027 let enum_name = &input.ident;
3028 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
3029 return generate_multi_tag_output(
3030 &input,
3031 enum_name,
3032 enum_data,
3033 prompt_template,
3034 actions_doc,
3035 );
3036 }
3037
3038 let extractor_tag = match extractor_tag {
3040 Some(t) => t,
3041 None => {
3042 return syn::Error::new(
3043 input.ident.span(),
3044 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
3045 )
3046 .to_compile_error()
3047 .into();
3048 }
3049 };
3050
3051 let enum_name = &input.ident;
3053 let enum_docs = extract_doc_comments(&input.attrs);
3054
3055 let mut intents_doc_lines = Vec::new();
3056
3057 if !enum_docs.is_empty() {
3059 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
3060 } else {
3061 intents_doc_lines.push(format!("{}:", enum_name));
3062 }
3063 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
3065
3066 for variant in &enum_data.variants {
3068 let variant_name = &variant.ident;
3069 let variant_docs = extract_doc_comments(&variant.attrs);
3070
3071 if !variant_docs.is_empty() {
3072 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
3073 } else {
3074 intents_doc_lines.push(format!("- {}", variant_name));
3075 }
3076 }
3077
3078 let intents_doc_str = intents_doc_lines.join("\n");
3079
3080 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
3082 let user_variables: Vec<String> = placeholders
3083 .iter()
3084 .filter_map(|(name, _)| {
3085 if name != "intents_doc" {
3086 Some(name.clone())
3087 } else {
3088 None
3089 }
3090 })
3091 .collect();
3092
3093 let enum_name_str = enum_name.to_string();
3095 let snake_case_name = to_snake_case(&enum_name_str);
3096 let function_name = syn::Ident::new(
3097 &format!("build_{}_prompt", snake_case_name),
3098 proc_macro2::Span::call_site(),
3099 );
3100
3101 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3103 .iter()
3104 .map(|var| {
3105 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3106 quote! { #ident: &str }
3107 })
3108 .collect();
3109
3110 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3112 .iter()
3113 .map(|var| {
3114 let var_str = var.clone();
3115 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3116 quote! {
3117 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
3118 }
3119 })
3120 .collect();
3121
3122 let converted_template = prompt_template.clone();
3124
3125 let extractor_name = syn::Ident::new(
3127 &format!("{}Extractor", enum_name),
3128 proc_macro2::Span::call_site(),
3129 );
3130
3131 let filtered_attrs: Vec<_> = input
3133 .attrs
3134 .iter()
3135 .filter(|attr| !attr.path().is_ident("intent"))
3136 .collect();
3137
3138 let vis = &input.vis;
3140 let generics = &input.generics;
3141 let variants = &enum_data.variants;
3142 let enum_output = quote! {
3143 #(#filtered_attrs)*
3144 #vis enum #enum_name #generics {
3145 #variants
3146 }
3147 };
3148
3149 let expanded = quote! {
3151 #enum_output
3153
3154 pub fn #function_name(#(#function_params),*) -> String {
3156 let mut env = minijinja::Environment::new();
3157 env.add_template("prompt", #converted_template)
3158 .expect("Failed to parse intent prompt template");
3159
3160 let tmpl = env.get_template("prompt").unwrap();
3161
3162 let mut __template_context = std::collections::HashMap::new();
3163
3164 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
3166
3167 #(#context_insertions)*
3169
3170 tmpl.render(&__template_context)
3171 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3172 }
3173
3174 pub struct #extractor_name;
3176
3177 impl #extractor_name {
3178 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
3179 }
3180
3181 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
3182 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
3183 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
3185 }
3186 }
3187 };
3188
3189 TokenStream::from(expanded)
3190}
3191
/// Converts a `CamelCase` name to `snake_case`.
///
/// An underscore is inserted before an uppercase character only when it is not
/// the first character and the preceding character was not uppercase, so a run
/// of capitals stays together: `"HTTPServer"` becomes `"httpserver"` and
/// `"MyHTTPIntent"` becomes `"my_httpintent"`. Non-uppercase characters are
/// copied through unchanged.
///
/// Uppercase characters are replaced by their complete lowercase mapping, which
/// may be more than one `char` (for example `'İ'` lowercases to `"i\u{307}"`).
fn to_snake_case(s: &str) -> String {
    // A few extra bytes of headroom for the inserted underscores.
    let mut result = String::with_capacity(s.len() + 4);
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        let is_upper = ch.is_uppercase();
        if is_upper {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` yields an iterator: keep every char it produces.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
        prev_upper = is_upper;
    }

    result
}
3212
/// Settings read from a variant-level `#[action(...)]` attribute.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// Value of `tag = "..."`, or `None` when no such key was found.
    tag: Option<String>,
}
3218
3219fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
3220 let mut result = ActionAttrs::default();
3221
3222 for attr in attrs {
3223 if attr.path().is_ident("action")
3224 && let Ok(meta_list) = attr.meta.require_list()
3225 && let Ok(metas) =
3226 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3227 {
3228 for meta in metas {
3229 if let Meta::NameValue(nv) = meta
3230 && nv.path.is_ident("tag")
3231 && let syn::Expr::Lit(syn::ExprLit {
3232 lit: syn::Lit::Str(lit_str),
3233 ..
3234 }) = nv.value
3235 {
3236 result.tag = Some(lit_str.value());
3237 }
3238 }
3239 }
3240 }
3241
3242 result
3243}
3244
/// Flags read from field-level `#[action(...)]` attributes.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// Set when the field carries `#[action(attribute)]`.
    is_attribute: bool,
    /// Set when the field carries `#[action(inner_text)]`.
    is_inner_text: bool,
}
3251
3252fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
3253 let mut result = FieldActionAttrs::default();
3254
3255 for attr in attrs {
3256 if attr.path().is_ident("action")
3257 && let Ok(meta_list) = attr.meta.require_list()
3258 {
3259 let tokens_str = meta_list.tokens.to_string();
3260 if tokens_str == "attribute" {
3261 result.is_attribute = true;
3262 } else if tokens_str == "inner_text" {
3263 result.is_inner_text = true;
3264 }
3265 }
3266 }
3267
3268 result
3269}
3270
3271fn generate_multi_tag_actions_doc(
3273 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3274) -> String {
3275 let mut doc_lines = Vec::new();
3276
3277 for variant in variants {
3278 let action_attrs = parse_action_attrs(&variant.attrs);
3279
3280 if let Some(tag) = action_attrs.tag {
3281 let variant_docs = extract_doc_comments(&variant.attrs);
3282
3283 match &variant.fields {
3284 syn::Fields::Unit => {
3285 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
3287 }
3288 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
3289 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
3291 }
3292 syn::Fields::Named(fields) => {
3293 let mut attrs_str = Vec::new();
3295 let mut has_inner_text = false;
3296
3297 for field in &fields.named {
3298 let field_name = field.ident.as_ref().unwrap();
3299 let field_attrs = parse_field_action_attrs(&field.attrs);
3300
3301 if field_attrs.is_attribute {
3302 attrs_str.push(format!("{}=\"...\"", field_name));
3303 } else if field_attrs.is_inner_text {
3304 has_inner_text = true;
3305 }
3306 }
3307
3308 let attrs_part = if !attrs_str.is_empty() {
3309 format!(" {}", attrs_str.join(" "))
3310 } else {
3311 String::new()
3312 };
3313
3314 if has_inner_text {
3315 doc_lines.push(format!(
3316 "- `<{}{}>...</{}>`: {}",
3317 tag, attrs_part, tag, variant_docs
3318 ));
3319 } else if !attrs_str.is_empty() {
3320 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
3321 } else {
3322 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
3323 }
3324
3325 for field in &fields.named {
3327 let field_name = field.ident.as_ref().unwrap();
3328 let field_attrs = parse_field_action_attrs(&field.attrs);
3329 let field_docs = extract_doc_comments(&field.attrs);
3330
3331 if field_attrs.is_attribute {
3332 doc_lines
3333 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
3334 } else if field_attrs.is_inner_text {
3335 doc_lines
3336 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
3337 }
3338 }
3339 }
3340 _ => {
3341 }
3343 }
3344 }
3345 }
3346
3347 doc_lines.join("\n")
3348}
3349
3350fn generate_tags_regex(
3352 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3353) -> String {
3354 let mut tag_names = Vec::new();
3355
3356 for variant in variants {
3357 let action_attrs = parse_action_attrs(&variant.attrs);
3358 if let Some(tag) = action_attrs.tag {
3359 tag_names.push(tag);
3360 }
3361 }
3362
3363 if tag_names.is_empty() {
3364 return String::new();
3365 }
3366
3367 let tags_pattern = tag_names.join("|");
3368 format!(
3371 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
3372 tags_pattern, tags_pattern, tags_pattern
3373 )
3374}
3375
3376fn generate_multi_tag_output(
3378 input: &DeriveInput,
3379 enum_name: &syn::Ident,
3380 enum_data: &syn::DataEnum,
3381 prompt_template: String,
3382 actions_doc: String,
3383) -> TokenStream {
3384 let found_crate =
3385 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3386 let crate_path = match found_crate {
3387 FoundCrate::Itself => {
3388 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3390 quote!(::#ident)
3391 }
3392 FoundCrate::Name(name) => {
3393 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3394 quote!(::#ident)
3395 }
3396 };
3397
3398 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
3400 let user_variables: Vec<String> = placeholders
3401 .iter()
3402 .filter_map(|(name, _)| {
3403 if name != "actions_doc" {
3404 Some(name.clone())
3405 } else {
3406 None
3407 }
3408 })
3409 .collect();
3410
3411 let enum_name_str = enum_name.to_string();
3413 let snake_case_name = to_snake_case(&enum_name_str);
3414 let function_name = syn::Ident::new(
3415 &format!("build_{}_prompt", snake_case_name),
3416 proc_macro2::Span::call_site(),
3417 );
3418
3419 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3421 .iter()
3422 .map(|var| {
3423 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3424 quote! { #ident: &str }
3425 })
3426 .collect();
3427
3428 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3430 .iter()
3431 .map(|var| {
3432 let var_str = var.clone();
3433 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3434 quote! {
3435 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
3436 }
3437 })
3438 .collect();
3439
3440 let extractor_name = syn::Ident::new(
3442 &format!("{}Extractor", enum_name),
3443 proc_macro2::Span::call_site(),
3444 );
3445
3446 let filtered_attrs: Vec<_> = input
3448 .attrs
3449 .iter()
3450 .filter(|attr| !attr.path().is_ident("intent"))
3451 .collect();
3452
3453 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
3455 .variants
3456 .iter()
3457 .map(|variant| {
3458 let variant_name = &variant.ident;
3459 let variant_attrs: Vec<_> = variant
3460 .attrs
3461 .iter()
3462 .filter(|attr| !attr.path().is_ident("action"))
3463 .collect();
3464 let fields = &variant.fields;
3465
3466 let filtered_fields = match fields {
3468 syn::Fields::Named(named_fields) => {
3469 let filtered: Vec<_> = named_fields
3470 .named
3471 .iter()
3472 .map(|field| {
3473 let field_name = &field.ident;
3474 let field_type = &field.ty;
3475 let field_vis = &field.vis;
3476 let filtered_attrs: Vec<_> = field
3477 .attrs
3478 .iter()
3479 .filter(|attr| !attr.path().is_ident("action"))
3480 .collect();
3481 quote! {
3482 #(#filtered_attrs)*
3483 #field_vis #field_name: #field_type
3484 }
3485 })
3486 .collect();
3487 quote! { { #(#filtered,)* } }
3488 }
3489 syn::Fields::Unnamed(unnamed_fields) => {
3490 let types: Vec<_> = unnamed_fields
3491 .unnamed
3492 .iter()
3493 .map(|field| {
3494 let field_type = &field.ty;
3495 quote! { #field_type }
3496 })
3497 .collect();
3498 quote! { (#(#types),*) }
3499 }
3500 syn::Fields::Unit => quote! {},
3501 };
3502
3503 quote! {
3504 #(#variant_attrs)*
3505 #variant_name #filtered_fields
3506 }
3507 })
3508 .collect();
3509
3510 let vis = &input.vis;
3511 let generics = &input.generics;
3512
3513 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
3515
3516 let tags_regex = generate_tags_regex(&enum_data.variants);
3518
3519 let expanded = quote! {
3520 #(#filtered_attrs)*
3522 #vis enum #enum_name #generics {
3523 #(#filtered_variants),*
3524 }
3525
3526 pub fn #function_name(#(#function_params),*) -> String {
3528 let mut env = minijinja::Environment::new();
3529 env.add_template("prompt", #prompt_template)
3530 .expect("Failed to parse intent prompt template");
3531
3532 let tmpl = env.get_template("prompt").unwrap();
3533
3534 let mut __template_context = std::collections::HashMap::new();
3535
3536 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
3538
3539 #(#context_insertions)*
3541
3542 tmpl.render(&__template_context)
3543 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3544 }
3545
3546 pub struct #extractor_name;
3548
3549 impl #extractor_name {
3550 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
3551 use ::quick_xml::events::Event;
3552 use ::quick_xml::Reader;
3553
3554 let mut actions = Vec::new();
3555 let mut reader = Reader::from_str(text);
3556 reader.config_mut().trim_text(true);
3557
3558 let mut buf = Vec::new();
3559
3560 loop {
3561 match reader.read_event_into(&mut buf) {
3562 Ok(Event::Start(e)) => {
3563 let owned_e = e.into_owned();
3564 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3565 let is_empty = false;
3566
3567 #parsing_arms
3568 }
3569 Ok(Event::Empty(e)) => {
3570 let owned_e = e.into_owned();
3571 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3572 let is_empty = true;
3573
3574 #parsing_arms
3575 }
3576 Ok(Event::Eof) => break,
3577 Err(_) => {
3578 break;
3580 }
3581 _ => {}
3582 }
3583 buf.clear();
3584 }
3585
3586 actions.into_iter().next()
3587 }
3588
3589 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
3590 use ::quick_xml::events::Event;
3591 use ::quick_xml::Reader;
3592
3593 let mut actions = Vec::new();
3594 let mut reader = Reader::from_str(text);
3595 reader.config_mut().trim_text(true);
3596
3597 let mut buf = Vec::new();
3598
3599 loop {
3600 match reader.read_event_into(&mut buf) {
3601 Ok(Event::Start(e)) => {
3602 let owned_e = e.into_owned();
3603 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3604 let is_empty = false;
3605
3606 #parsing_arms
3607 }
3608 Ok(Event::Empty(e)) => {
3609 let owned_e = e.into_owned();
3610 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3611 let is_empty = true;
3612
3613 #parsing_arms
3614 }
3615 Ok(Event::Eof) => break,
3616 Err(_) => {
3617 break;
3619 }
3620 _ => {}
3621 }
3622 buf.clear();
3623 }
3624
3625 Ok(actions)
3626 }
3627
3628 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
3629 where
3630 F: FnMut(#enum_name) -> String,
3631 {
3632 use ::regex::Regex;
3633
3634 let regex_pattern = #tags_regex;
3635 if regex_pattern.is_empty() {
3636 return text.to_string();
3637 }
3638
3639 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
3640 panic!("Failed to compile regex for action tags: {}", e);
3641 });
3642
3643 re.replace_all(text, |caps: &::regex::Captures| {
3644 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
3645
3646 if let Some(action) = self.parse_single_action(matched) {
3648 transformer(action)
3649 } else {
3650 matched.to_string()
3652 }
3653 }).to_string()
3654 }
3655
3656 pub fn strip_actions(&self, text: &str) -> String {
3657 self.transform_actions(text, |_| String::new())
3658 }
3659 }
3660 };
3661
3662 TokenStream::from(expanded)
3663}
3664
3665fn generate_parsing_arms(
3667 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3668 enum_name: &syn::Ident,
3669) -> proc_macro2::TokenStream {
3670 let mut arms = Vec::new();
3671
3672 for variant in variants {
3673 let variant_name = &variant.ident;
3674 let action_attrs = parse_action_attrs(&variant.attrs);
3675
3676 if let Some(tag) = action_attrs.tag {
3677 match &variant.fields {
3678 syn::Fields::Unit => {
3679 arms.push(quote! {
3681 if &tag_name == #tag {
3682 actions.push(#enum_name::#variant_name);
3683 }
3684 });
3685 }
3686 syn::Fields::Unnamed(_fields) => {
3687 arms.push(quote! {
3689 if &tag_name == #tag && !is_empty {
3690 match reader.read_text(owned_e.name()) {
3692 Ok(text) => {
3693 actions.push(#enum_name::#variant_name(text.to_string()));
3694 }
3695 Err(_) => {
3696 actions.push(#enum_name::#variant_name(String::new()));
3698 }
3699 }
3700 }
3701 });
3702 }
3703 syn::Fields::Named(fields) => {
3704 let mut field_names = Vec::new();
3706 let mut has_inner_text_field = None;
3707
3708 for field in &fields.named {
3709 let field_name = field.ident.as_ref().unwrap();
3710 let field_attrs = parse_field_action_attrs(&field.attrs);
3711
3712 if field_attrs.is_attribute {
3713 field_names.push(field_name.clone());
3714 } else if field_attrs.is_inner_text {
3715 has_inner_text_field = Some(field_name.clone());
3716 }
3717 }
3718
3719 if let Some(inner_text_field) = has_inner_text_field {
3720 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3723 quote! {
3724 let mut #field_name = String::new();
3725 for attr in owned_e.attributes() {
3726 if let Ok(attr) = attr {
3727 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3728 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3729 break;
3730 }
3731 }
3732 }
3733 }
3734 }).collect();
3735
3736 arms.push(quote! {
3737 if &tag_name == #tag {
3738 #(#attr_extractions)*
3739
3740 if is_empty {
3742 let #inner_text_field = String::new();
3743 actions.push(#enum_name::#variant_name {
3744 #(#field_names,)*
3745 #inner_text_field,
3746 });
3747 } else {
3748 match reader.read_text(owned_e.name()) {
3750 Ok(text) => {
3751 let #inner_text_field = text.to_string();
3752 actions.push(#enum_name::#variant_name {
3753 #(#field_names,)*
3754 #inner_text_field,
3755 });
3756 }
3757 Err(_) => {
3758 let #inner_text_field = String::new();
3760 actions.push(#enum_name::#variant_name {
3761 #(#field_names,)*
3762 #inner_text_field,
3763 });
3764 }
3765 }
3766 }
3767 }
3768 });
3769 } else {
3770 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3772 quote! {
3773 let mut #field_name = String::new();
3774 for attr in owned_e.attributes() {
3775 if let Ok(attr) = attr {
3776 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3777 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3778 break;
3779 }
3780 }
3781 }
3782 }
3783 }).collect();
3784
3785 arms.push(quote! {
3786 if &tag_name == #tag {
3787 #(#attr_extractions)*
3788 actions.push(#enum_name::#variant_name {
3789 #(#field_names),*
3790 });
3791 }
3792 });
3793 }
3794 }
3795 }
3796 }
3797 }
3798
3799 quote! {
3800 #(#arms)*
3801 }
3802}
3803
3804#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3806pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3807 let input = parse_macro_input!(input as DeriveInput);
3808
3809 let found_crate =
3810 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3811 let crate_path = match found_crate {
3812 FoundCrate::Itself => {
3813 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3815 quote!(::#ident)
3816 }
3817 FoundCrate::Name(name) => {
3818 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3819 quote!(::#ident)
3820 }
3821 };
3822
3823 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
3825
3826 let struct_name = &input.ident;
3827 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3828
3829 let placeholders = parse_template_placeholders_with_mode(&template);
3831
3832 let mut converted_template = template.clone();
3834 let mut context_fields = Vec::new();
3835
3836 let fields = match &input.data {
3838 Data::Struct(data_struct) => match &data_struct.fields {
3839 syn::Fields::Named(fields) => &fields.named,
3840 _ => panic!("ToPromptFor is only supported for structs with named fields"),
3841 },
3842 _ => panic!("ToPromptFor is only supported for structs"),
3843 };
3844
3845 let has_mode_support = input.attrs.iter().any(|attr| {
3847 if attr.path().is_ident("prompt")
3848 && let Ok(metas) =
3849 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3850 {
3851 for meta in metas {
3852 if let Meta::NameValue(nv) = meta
3853 && nv.path.is_ident("mode")
3854 {
3855 return true;
3856 }
3857 }
3858 }
3859 false
3860 });
3861
3862 for (placeholder_name, mode_opt) in &placeholders {
3864 if placeholder_name == "self" {
3865 if let Some(specific_mode) = mode_opt {
3866 let unique_key = format!("self__{}", specific_mode);
3868
3869 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3871 let replacement = format!("{{{{ {} }}}}", unique_key);
3872 converted_template = converted_template.replace(&pattern, &replacement);
3873
3874 context_fields.push(quote! {
3876 context.insert(
3877 #unique_key.to_string(),
3878 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3879 );
3880 });
3881 } else {
3882 if has_mode_support {
3885 context_fields.push(quote! {
3887 context.insert(
3888 "self".to_string(),
3889 minijinja::Value::from(self.to_prompt_with_mode(mode))
3890 );
3891 });
3892 } else {
3893 context_fields.push(quote! {
3895 context.insert(
3896 "self".to_string(),
3897 minijinja::Value::from(self.to_prompt())
3898 );
3899 });
3900 }
3901 }
3902 } else {
3903 let field_exists = fields.iter().any(|f| {
3906 f.ident
3907 .as_ref()
3908 .is_some_and(|ident| ident == placeholder_name)
3909 });
3910
3911 if field_exists {
3912 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3913
3914 context_fields.push(quote! {
3918 context.insert(
3919 #placeholder_name.to_string(),
3920 minijinja::Value::from_serialize(&self.#field_ident)
3921 );
3922 });
3923 }
3924 }
3926 }
3927
3928 let expanded = quote! {
3929 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3930 where
3931 #target_type: serde::Serialize,
3932 {
3933 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3934 let mut env = minijinja::Environment::new();
3936 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3937 panic!("Failed to parse template: {}", e)
3938 });
3939
3940 let tmpl = env.get_template("prompt").unwrap();
3941
3942 let mut context = std::collections::HashMap::new();
3944 context.insert(
3946 "self".to_string(),
3947 minijinja::Value::from_serialize(self)
3948 );
3949 context.insert(
3951 "target".to_string(),
3952 minijinja::Value::from_serialize(target)
3953 );
3954 #(#context_fields)*
3955
3956 tmpl.render(context).unwrap_or_else(|e| {
3958 format!("Failed to render prompt: {}", e)
3959 })
3960 }
3961 }
3962 };
3963
3964 TokenStream::from(expanded)
3965}
3966
3967struct AgentAttrs {
3973 expertise: Option<String>,
3974 output: Option<syn::Type>,
3975 backend: Option<String>,
3976 model: Option<String>,
3977 inner: Option<String>,
3978 default_inner: Option<String>,
3979 max_retries: Option<u32>,
3980 profile: Option<String>,
3981
3982 persona: Option<syn::Expr>,
3983}
3984
3985impl Parse for AgentAttrs {
3986 fn parse(input: ParseStream) -> syn::Result<Self> {
3987 let mut expertise = None;
3988 let mut output = None;
3989 let mut backend = None;
3990 let mut model = None;
3991 let mut inner = None;
3992 let mut default_inner = None;
3993 let mut max_retries = None;
3994 let mut profile = None;
3995 let mut persona = None;
3996
3997 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3998
3999 for meta in pairs {
4000 match meta {
4001 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
4002 if let syn::Expr::Lit(syn::ExprLit {
4003 lit: syn::Lit::Str(lit_str),
4004 ..
4005 }) = &nv.value
4006 {
4007 expertise = Some(lit_str.value());
4008 }
4009 }
4010 Meta::NameValue(nv) if nv.path.is_ident("output") => {
4011 if let syn::Expr::Lit(syn::ExprLit {
4012 lit: syn::Lit::Str(lit_str),
4013 ..
4014 }) = &nv.value
4015 {
4016 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
4017 output = Some(ty);
4018 }
4019 }
4020 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
4021 if let syn::Expr::Lit(syn::ExprLit {
4022 lit: syn::Lit::Str(lit_str),
4023 ..
4024 }) = &nv.value
4025 {
4026 backend = Some(lit_str.value());
4027 }
4028 }
4029 Meta::NameValue(nv) if nv.path.is_ident("model") => {
4030 if let syn::Expr::Lit(syn::ExprLit {
4031 lit: syn::Lit::Str(lit_str),
4032 ..
4033 }) = &nv.value
4034 {
4035 model = Some(lit_str.value());
4036 }
4037 }
4038 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
4039 if let syn::Expr::Lit(syn::ExprLit {
4040 lit: syn::Lit::Str(lit_str),
4041 ..
4042 }) = &nv.value
4043 {
4044 inner = Some(lit_str.value());
4045 }
4046 }
4047 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
4048 if let syn::Expr::Lit(syn::ExprLit {
4049 lit: syn::Lit::Str(lit_str),
4050 ..
4051 }) = &nv.value
4052 {
4053 default_inner = Some(lit_str.value());
4054 }
4055 }
4056 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
4057 if let syn::Expr::Lit(syn::ExprLit {
4058 lit: syn::Lit::Int(lit_int),
4059 ..
4060 }) = &nv.value
4061 {
4062 max_retries = Some(lit_int.base10_parse()?);
4063 }
4064 }
4065 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
4066 if let syn::Expr::Lit(syn::ExprLit {
4067 lit: syn::Lit::Str(lit_str),
4068 ..
4069 }) = &nv.value
4070 {
4071 profile = Some(lit_str.value());
4072 }
4073 }
4074 Meta::NameValue(nv) if nv.path.is_ident("persona") => {
4075 if let syn::Expr::Lit(syn::ExprLit {
4076 lit: syn::Lit::Str(lit_str),
4077 ..
4078 }) = &nv.value
4079 {
4080 let expr: syn::Expr = syn::parse_str(&lit_str.value())?;
4082 persona = Some(expr);
4083 }
4084 }
4085 _ => {}
4086 }
4087 }
4088
4089 Ok(AgentAttrs {
4090 expertise,
4091 output,
4092 backend,
4093 model,
4094 inner,
4095 default_inner,
4096 max_retries,
4097 profile,
4098 persona,
4099 })
4100 }
4101}
4102
4103fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
4105 for attr in attrs {
4106 if attr.path().is_ident("agent") {
4107 return attr.parse_args::<AgentAttrs>();
4108 }
4109 }
4110
4111 Ok(AgentAttrs {
4112 expertise: None,
4113 output: None,
4114 backend: None,
4115 model: None,
4116 inner: None,
4117 default_inner: None,
4118 max_retries: None,
4119 profile: None,
4120 persona: None,
4121 })
4122}
4123
4124fn generate_backend_constructors(
4126 struct_name: &syn::Ident,
4127 backend: &str,
4128 _model: Option<&str>,
4129 _profile: Option<&str>,
4130 crate_path: &proc_macro2::TokenStream,
4131) -> proc_macro2::TokenStream {
4132 match backend {
4133 "claude" => {
4134 quote! {
4135 impl #struct_name {
4136 pub fn with_claude() -> Self {
4138 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
4139 }
4140
4141 pub fn with_claude_model(model: &str) -> Self {
4143 Self::new(
4144 #crate_path::agent::impls::ClaudeCodeAgent::new()
4145 .with_model_str(model)
4146 )
4147 }
4148 }
4149 }
4150 }
4151 "gemini" => {
4152 quote! {
4153 impl #struct_name {
4154 pub fn with_gemini() -> Self {
4156 Self::new(#crate_path::agent::impls::GeminiAgent::new())
4157 }
4158
4159 pub fn with_gemini_model(model: &str) -> Self {
4161 Self::new(
4162 #crate_path::agent::impls::GeminiAgent::new()
4163 .with_model_str(model)
4164 )
4165 }
4166 }
4167 }
4168 }
4169 _ => quote! {},
4170 }
4171}
4172
4173fn generate_default_impl(
4175 struct_name: &syn::Ident,
4176 backend: &str,
4177 model: Option<&str>,
4178 profile: Option<&str>,
4179 crate_path: &proc_macro2::TokenStream,
4180) -> proc_macro2::TokenStream {
4181 let profile_expr = if let Some(profile_str) = profile {
4183 match profile_str.to_lowercase().as_str() {
4184 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4185 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4186 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
4187 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
4189 } else {
4190 quote! { #crate_path::agent::ExecutionProfile::default() }
4191 };
4192
4193 let agent_init = match backend {
4194 "gemini" => {
4195 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
4196
4197 if let Some(model_str) = model {
4198 builder = quote! { #builder.with_model_str(#model_str) };
4199 }
4200
4201 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4202 builder
4203 }
4204 _ => {
4205 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
4207
4208 if let Some(model_str) = model {
4209 builder = quote! { #builder.with_model_str(#model_str) };
4210 }
4211
4212 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4213 builder
4214 }
4215 };
4216
4217 quote! {
4218 impl Default for #struct_name {
4219 fn default() -> Self {
4220 Self::new(#agent_init)
4221 }
4222 }
4223 }
4224}
4225
4226#[proc_macro_derive(Agent, attributes(agent))]
4235pub fn derive_agent(input: TokenStream) -> TokenStream {
4236 let input = parse_macro_input!(input as DeriveInput);
4237 let struct_name = &input.ident;
4238
4239 let agent_attrs = match parse_agent_attrs(&input.attrs) {
4241 Ok(attrs) => attrs,
4242 Err(e) => return e.to_compile_error().into(),
4243 };
4244
4245 let expertise = agent_attrs
4246 .expertise
4247 .unwrap_or_else(|| String::from("general AI assistant"));
4248 let output_type = agent_attrs
4249 .output
4250 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4251 let backend = agent_attrs
4252 .backend
4253 .unwrap_or_else(|| String::from("claude"));
4254 let model = agent_attrs.model;
4255 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
4260 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4261 let crate_path = match found_crate {
4262 FoundCrate::Itself => {
4263 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4265 quote!(::#ident)
4266 }
4267 FoundCrate::Name(name) => {
4268 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4269 quote!(::#ident)
4270 }
4271 };
4272
4273 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
4274
4275 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4277 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4278
4279 let enhanced_expertise = if is_string_output {
4281 quote! { #expertise }
4283 } else {
4284 let type_name = quote!(#output_type).to_string();
4286 quote! {
4287 {
4288 use std::sync::OnceLock;
4289 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4290
4291 EXPERTISE_CACHE.get_or_init(|| {
4292 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4294
4295 if schema.is_empty() {
4296 format!(
4298 concat!(
4299 #expertise,
4300 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4301 "Do not include any text outside the JSON object."
4302 ),
4303 #type_name
4304 )
4305 } else {
4306 format!(
4308 concat!(
4309 #expertise,
4310 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4311 ),
4312 schema
4313 )
4314 }
4315 }).as_str()
4316 }
4317 }
4318 };
4319
4320 let agent_init = match backend.as_str() {
4322 "gemini" => {
4323 if let Some(model_str) = model {
4324 quote! {
4325 use #crate_path::agent::impls::GeminiAgent;
4326 let agent = GeminiAgent::new().with_model_str(#model_str);
4327 }
4328 } else {
4329 quote! {
4330 use #crate_path::agent::impls::GeminiAgent;
4331 let agent = GeminiAgent::new();
4332 }
4333 }
4334 }
4335 "claude" => {
4336 if let Some(model_str) = model {
4337 quote! {
4338 use #crate_path::agent::impls::ClaudeCodeAgent;
4339 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
4340 }
4341 } else {
4342 quote! {
4343 use #crate_path::agent::impls::ClaudeCodeAgent;
4344 let agent = ClaudeCodeAgent::new();
4345 }
4346 }
4347 }
4348 _ => {
4349 if let Some(model_str) = model {
4351 quote! {
4352 use #crate_path::agent::impls::ClaudeCodeAgent;
4353 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
4354 }
4355 } else {
4356 quote! {
4357 use #crate_path::agent::impls::ClaudeCodeAgent;
4358 let agent = ClaudeCodeAgent::new();
4359 }
4360 }
4361 }
4362 };
4363
4364 let expanded = quote! {
4365 #[async_trait::async_trait]
4366 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
4367 type Output = #output_type;
4368
4369 fn expertise(&self) -> &str {
4370 #enhanced_expertise
4371 }
4372
4373 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4374 #agent_init
4376
4377 let agent_ref = &agent;
4379 #crate_path::agent::retry::retry_execution(
4380 #max_retries,
4381 &intent,
4382 move |payload| {
4383 let payload = payload.clone();
4384 async move {
4385 let response = agent_ref.execute(payload).await?;
4387
4388 let json_str = #crate_path::extract_json(&response)
4390 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4391 message: format!("Failed to extract JSON: {}", e),
4392 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4393 })?;
4394
4395 serde_json::from_str::<Self::Output>(&json_str)
4397 .map_err(|e| {
4398 let reason = if e.is_eof() {
4400 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4401 } else if e.is_syntax() {
4402 #crate_path::agent::error::ParseErrorReason::InvalidJson
4403 } else {
4404 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4405 };
4406
4407 #crate_path::agent::AgentError::ParseError {
4408 message: format!("Failed to parse JSON: {}", e),
4409 reason,
4410 }
4411 })
4412 }
4413 }
4414 ).await
4415 }
4416
4417 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4418 #agent_init
4420 agent.is_available().await
4421 }
4422 }
4423 };
4424
4425 TokenStream::from(expanded)
4426}
4427
4428#[proc_macro_attribute]
4443pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
4444 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
4446 Ok(attrs) => attrs,
4447 Err(e) => return e.to_compile_error().into(),
4448 };
4449
4450 let input = parse_macro_input!(item as DeriveInput);
4452 let struct_name = &input.ident;
4453 let struct_name_str = struct_name.to_string();
4454 let vis = &input.vis;
4455
4456 let expertise = agent_attrs
4457 .expertise
4458 .unwrap_or_else(|| String::from("general AI assistant"));
4459 let output_type = agent_attrs
4460 .output
4461 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4462 let backend = agent_attrs
4463 .backend
4464 .unwrap_or_else(|| String::from("claude"));
4465 let model = agent_attrs.model;
4466 let profile = agent_attrs.profile;
4467 let persona = agent_attrs.persona;
4468
4469 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4471 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4472
4473 let found_crate =
4475 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4476 let crate_path = match found_crate {
4477 FoundCrate::Itself => {
4478 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4479 quote!(::#ident)
4480 }
4481 FoundCrate::Name(name) => {
4482 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4483 quote!(::#ident)
4484 }
4485 };
4486
4487 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
4489 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
4490
4491 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
4493 let type_path: syn::Type =
4495 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
4496 quote! { #type_path }
4497 } else {
4498 match backend.as_str() {
4500 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
4501 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
4502 }
4503 };
4504
4505 let (struct_def, _actual_inner_type, uses_persona) = if let Some(ref _persona_path) = persona {
4507 let wrapped_type =
4510 quote! { #crate_path::agent::persona::PersonaAgent<#inner_generic_ident> };
4511 let struct_def = quote! {
4512 #vis struct #struct_name<#inner_generic_ident: #crate_path::agent::Agent + Send + Sync = #default_agent_type> {
4513 inner: #wrapped_type,
4514 }
4515 };
4516 (struct_def, wrapped_type, true)
4517 } else {
4518 let struct_def = quote! {
4520 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
4521 inner: #inner_generic_ident,
4522 }
4523 };
4524 (struct_def, quote! { #inner_generic_ident }, false)
4525 };
4526
4527 let constructors = if let Some(ref persona_path) = persona {
4529 quote! {
4530 impl<#inner_generic_ident: #crate_path::agent::Agent + Send + Sync> #struct_name<#inner_generic_ident> {
4531 pub fn new(inner: #inner_generic_ident) -> Self {
4533 let persona_agent = #crate_path::agent::persona::PersonaAgent::new(
4534 inner,
4535 #persona_path.clone()
4536 );
4537 Self { inner: persona_agent }
4538 }
4539 }
4540 }
4541 } else {
4542 quote! {
4543 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
4544 pub fn new(inner: #inner_generic_ident) -> Self {
4546 Self { inner }
4547 }
4548 }
4549 }
4550 };
4551
4552 let (backend_constructors, default_impl) = if let Some(ref _persona_path) = persona {
4554 let agent_init = match backend.as_str() {
4556 "gemini" => {
4557 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
4558 if let Some(model_str) = model.as_deref() {
4559 builder = quote! { #builder.with_model_str(#model_str) };
4560 }
4561 if let Some(profile_str) = profile.as_deref() {
4562 let profile_expr = match profile_str.to_lowercase().as_str() {
4563 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4564 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4565 "deterministic" => {
4566 quote! { #crate_path::agent::ExecutionProfile::Deterministic }
4567 }
4568 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4569 };
4570 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4571 }
4572 builder
4573 }
4574 _ => {
4575 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
4576 if let Some(model_str) = model.as_deref() {
4577 builder = quote! { #builder.with_model_str(#model_str) };
4578 }
4579 if let Some(profile_str) = profile.as_deref() {
4580 let profile_expr = match profile_str.to_lowercase().as_str() {
4581 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4582 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4583 "deterministic" => {
4584 quote! { #crate_path::agent::ExecutionProfile::Deterministic }
4585 }
4586 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4587 };
4588 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4589 }
4590 builder
4591 }
4592 };
4593
4594 let backend_constructors = match backend.as_str() {
4595 "claude" => {
4596 quote! {
4597 impl #struct_name {
4598 pub fn with_claude() -> Self {
4600 let base_agent = #crate_path::agent::impls::ClaudeCodeAgent::new();
4601 Self::new(base_agent)
4602 }
4603
4604 pub fn with_claude_model(model: &str) -> Self {
4606 let base_agent = #crate_path::agent::impls::ClaudeCodeAgent::new()
4607 .with_model_str(model);
4608 Self::new(base_agent)
4609 }
4610 }
4611 }
4612 }
4613 "gemini" => {
4614 quote! {
4615 impl #struct_name {
4616 pub fn with_gemini() -> Self {
4618 let base_agent = #crate_path::agent::impls::GeminiAgent::new();
4619 Self::new(base_agent)
4620 }
4621
4622 pub fn with_gemini_model(model: &str) -> Self {
4624 let base_agent = #crate_path::agent::impls::GeminiAgent::new()
4625 .with_model_str(model);
4626 Self::new(base_agent)
4627 }
4628 }
4629 }
4630 }
4631 _ => quote! {},
4632 };
4633
4634 let default_impl = quote! {
4635 impl Default for #struct_name {
4636 fn default() -> Self {
4637 let base_agent = #agent_init;
4638 Self::new(base_agent)
4639 }
4640 }
4641 };
4642
4643 (backend_constructors, default_impl)
4644 } else if agent_attrs.default_inner.is_some() {
4645 let default_impl = quote! {
4647 impl Default for #struct_name {
4648 fn default() -> Self {
4649 Self {
4650 inner: <#default_agent_type as Default>::default(),
4651 }
4652 }
4653 }
4654 };
4655 (quote! {}, default_impl)
4656 } else {
4657 let backend_constructors = generate_backend_constructors(
4659 struct_name,
4660 &backend,
4661 model.as_deref(),
4662 profile.as_deref(),
4663 &crate_path,
4664 );
4665 let default_impl = generate_default_impl(
4666 struct_name,
4667 &backend,
4668 model.as_deref(),
4669 profile.as_deref(),
4670 &crate_path,
4671 );
4672 (backend_constructors, default_impl)
4673 };
4674
4675 let enhanced_expertise = if is_string_output {
4677 quote! { #expertise }
4679 } else {
4680 let type_name = quote!(#output_type).to_string();
4682 quote! {
4683 {
4684 use std::sync::OnceLock;
4685 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4686
4687 EXPERTISE_CACHE.get_or_init(|| {
4688 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4690
4691 if schema.is_empty() {
4692 format!(
4694 concat!(
4695 #expertise,
4696 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4697 "Do not include any text outside the JSON object."
4698 ),
4699 #type_name
4700 )
4701 } else {
4702 format!(
4704 concat!(
4705 #expertise,
4706 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4707 ),
4708 schema
4709 )
4710 }
4711 }).as_str()
4712 }
4713 }
4714 };
4715
4716 let agent_impl = if uses_persona {
4718 quote! {
4720 #[async_trait::async_trait]
4721 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4722 where
4723 #inner_generic_ident: #crate_path::agent::Agent + Send + Sync,
4724 <#inner_generic_ident as #crate_path::agent::Agent>::Output: Send,
4725 {
4726 type Output = <#inner_generic_ident as #crate_path::agent::Agent>::Output;
4727
4728 fn expertise(&self) -> &str {
4729 self.inner.expertise()
4730 }
4731
4732 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4733 self.inner.execute(intent).await
4734 }
4735
4736 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4737 self.inner.is_available().await
4738 }
4739 }
4740 }
4741 } else {
4742 quote! {
4744 #[async_trait::async_trait]
4745 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4746 where
4747 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
4748 {
4749 type Output = #output_type;
4750
4751 fn expertise(&self) -> &str {
4752 #enhanced_expertise
4753 }
4754
4755 #[#crate_path::tracing::instrument(name = "agent.execute", skip_all, fields(agent.name = #struct_name_str, agent.expertise = self.expertise()))]
4756 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4757 let enhanced_payload = intent.prepend_text(self.expertise());
4759
4760 let response = self.inner.execute(enhanced_payload).await?;
4762
4763 let json_str = #crate_path::extract_json(&response)
4765 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4766 message: e.to_string(),
4767 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4768 })?;
4769
4770 serde_json::from_str(&json_str).map_err(|e| {
4772 let reason = if e.is_eof() {
4773 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4774 } else if e.is_syntax() {
4775 #crate_path::agent::error::ParseErrorReason::InvalidJson
4776 } else {
4777 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4778 };
4779 #crate_path::agent::AgentError::ParseError {
4780 message: e.to_string(),
4781 reason,
4782 }
4783 })
4784 }
4785
4786 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4787 self.inner.is_available().await
4788 }
4789 }
4790 }
4791 };
4792
4793 let expanded = quote! {
4794 #struct_def
4795 #constructors
4796 #backend_constructors
4797 #default_impl
4798 #agent_impl
4799 };
4800
4801 TokenStream::from(expanded)
4802}
4803
4804#[proc_macro_derive(TypeMarker)]
4826pub fn derive_type_marker(input: TokenStream) -> TokenStream {
4827 let input = parse_macro_input!(input as DeriveInput);
4828 let struct_name = &input.ident;
4829 let type_name_str = struct_name.to_string();
4830
4831 let found_crate =
4833 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4834 let crate_path = match found_crate {
4835 FoundCrate::Itself => {
4836 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4837 quote!(::#ident)
4838 }
4839 FoundCrate::Name(name) => {
4840 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4841 quote!(::#ident)
4842 }
4843 };
4844
4845 let expanded = quote! {
4846 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4847 const TYPE_NAME: &'static str = #type_name_str;
4848 }
4849 };
4850
4851 TokenStream::from(expanded)
4852}
4853
4854#[proc_macro_attribute]
4890pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
4891 let input = parse_macro_input!(item as syn::DeriveInput);
4892 let struct_name = &input.ident;
4893 let vis = &input.vis;
4894 let type_name_str = struct_name.to_string();
4895
4896 let default_fn_name = syn::Ident::new(
4898 &format!("default_{}_type", to_snake_case(&type_name_str)),
4899 struct_name.span(),
4900 );
4901
4902 let found_crate =
4904 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4905 let crate_path = match found_crate {
4906 FoundCrate::Itself => {
4907 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4908 quote!(::#ident)
4909 }
4910 FoundCrate::Name(name) => {
4911 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4912 quote!(::#ident)
4913 }
4914 };
4915
4916 let fields = match &input.data {
4918 syn::Data::Struct(data_struct) => match &data_struct.fields {
4919 syn::Fields::Named(fields) => &fields.named,
4920 _ => {
4921 return syn::Error::new_spanned(
4922 struct_name,
4923 "type_marker only works with structs with named fields",
4924 )
4925 .to_compile_error()
4926 .into();
4927 }
4928 },
4929 _ => {
4930 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
4931 .to_compile_error()
4932 .into();
4933 }
4934 };
4935
4936 let mut new_fields = vec![];
4938
4939 let default_fn_name_str = default_fn_name.to_string();
4941 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
4942
4943 new_fields.push(quote! {
4948 #[serde(default = #default_fn_name_lit)]
4949 __type: String
4950 });
4951
4952 for field in fields {
4954 new_fields.push(quote! { #field });
4955 }
4956
4957 let attrs = &input.attrs;
4959 let generics = &input.generics;
4960
4961 let expanded = quote! {
4962 fn #default_fn_name() -> String {
4964 #type_name_str.to_string()
4965 }
4966
4967 #(#attrs)*
4969 #vis struct #struct_name #generics {
4970 #(#new_fields),*
4971 }
4972
4973 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4975 const TYPE_NAME: &'static str = #type_name_str;
4976 }
4977 };
4978
4979 TokenStream::from(expanded)
4980}