1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, vec_inner_type) = extract_vec_inner_type(&field.ty);
169 let (is_option, option_inner_type) = extract_option_inner_type(&field.ty);
170 let (is_map, map_value_type) = extract_map_value_type(&field.ty);
171 let (is_set, set_element_type) = extract_set_element_type(&field.ty);
172
173 if is_vec {
174 let comment = if !field_docs.is_empty() {
177 format!(" // {}", field_docs)
178 } else {
179 String::new()
180 };
181
182 field_schema_parts.push(quote! {
183 {
184 let type_name = stringify!(#vec_inner_type);
185 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
186 }
187 });
188
189 if let Some(inner) = vec_inner_type
191 && !is_primitive_type(inner)
192 {
193 nested_type_collectors.push(quote! {
194 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
195 });
196 }
197 } else if is_option {
198 let comment = if !field_docs.is_empty() {
201 format!(" // {}", field_docs)
202 } else {
203 String::new()
204 };
205
206 if let Some(inner) = option_inner_type {
207 let (inner_is_map, inner_map_value) = extract_map_value_type(inner);
209 let (inner_is_set, inner_set_element) = extract_set_element_type(inner);
210 let (inner_is_vec, inner_vec_element) = extract_vec_inner_type(inner);
211
212 if inner_is_map {
213 if let Some(value_type) = inner_map_value {
215 let is_value_primitive = is_primitive_type(value_type);
216 if !is_value_primitive {
217 field_schema_parts.push(quote! {
218 {
219 let type_name = stringify!(#value_type);
220 format!(" {}: Record<string, {}> | null;{}", #field_name_str, type_name, #comment)
221 }
222 });
223 nested_type_collectors.push(quote! {
224 <#value_type as #crate_path::prompt::ToPrompt>::prompt_schema()
225 });
226 } else {
227 let type_str = format_type_for_schema(value_type);
228 field_schema_parts.push(quote! {
229 format!(" {}: Record<string, {}> | null;{}", #field_name_str, #type_str, #comment)
230 });
231 }
232 }
233 } else if inner_is_set {
234 if let Some(element_type) = inner_set_element {
236 let is_element_primitive = is_primitive_type(element_type);
237 if !is_element_primitive {
238 field_schema_parts.push(quote! {
239 {
240 let type_name = stringify!(#element_type);
241 format!(" {}: {}[] | null;{}", #field_name_str, type_name, #comment)
242 }
243 });
244 nested_type_collectors.push(quote! {
245 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
246 });
247 } else {
248 let type_str = format_type_for_schema(element_type);
249 field_schema_parts.push(quote! {
250 format!(" {}: {}[] | null;{}", #field_name_str, #type_str, #comment)
251 });
252 }
253 }
254 } else if inner_is_vec {
255 if let Some(element_type) = inner_vec_element {
257 let is_element_primitive = is_primitive_type(element_type);
258 if !is_element_primitive {
259 field_schema_parts.push(quote! {
260 {
261 let type_name = stringify!(#element_type);
262 format!(" {}: {}[] | null;{}", #field_name_str, type_name, #comment)
263 }
264 });
265 nested_type_collectors.push(quote! {
266 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
267 });
268 } else {
269 let type_str = format_type_for_schema(element_type);
270 field_schema_parts.push(quote! {
271 format!(" {}: {}[] | null;{}", #field_name_str, #type_str, #comment)
272 });
273 }
274 }
275 } else {
276 let is_inner_primitive = is_primitive_type(inner);
278
279 if !is_inner_primitive {
280 field_schema_parts.push(quote! {
282 {
283 let type_name = stringify!(#inner);
284 format!(" {}: {} | null;{}", #field_name_str, type_name, #comment)
285 }
286 });
287
288 nested_type_collectors.push(quote! {
290 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
291 });
292 } else {
293 let type_str = format_type_for_schema(inner);
295 field_schema_parts.push(quote! {
296 format!(" {}: {} | null;{}", #field_name_str, #type_str, #comment)
297 });
298 }
299 }
300 }
301 } else if is_map {
302 let comment = if !field_docs.is_empty() {
305 format!(" // {}", field_docs)
306 } else {
307 String::new()
308 };
309
310 if let Some(value_type) = map_value_type {
311 let is_value_primitive = is_primitive_type(value_type);
312
313 if !is_value_primitive {
314 field_schema_parts.push(quote! {
316 {
317 let type_name = stringify!(#value_type);
318 format!(" {}: Record<string, {}>;{}", #field_name_str, type_name, #comment)
319 }
320 });
321
322 nested_type_collectors.push(quote! {
324 <#value_type as #crate_path::prompt::ToPrompt>::prompt_schema()
325 });
326 } else {
327 let type_str = format_type_for_schema(value_type);
329 field_schema_parts.push(quote! {
330 format!(" {}: Record<string, {}>;{}", #field_name_str, #type_str, #comment)
331 });
332 }
333 }
334 } else if is_set {
335 let comment = if !field_docs.is_empty() {
338 format!(" // {}", field_docs)
339 } else {
340 String::new()
341 };
342
343 if let Some(element_type) = set_element_type {
344 let is_element_primitive = is_primitive_type(element_type);
345
346 if !is_element_primitive {
347 field_schema_parts.push(quote! {
349 {
350 let type_name = stringify!(#element_type);
351 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
352 }
353 });
354
355 nested_type_collectors.push(quote! {
357 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
358 });
359 } else {
360 let type_str = format_type_for_schema(element_type);
362 field_schema_parts.push(quote! {
363 format!(" {}: {}[];{}", #field_name_str, #type_str, #comment)
364 });
365 }
366 }
367 } else {
368 let field_type = &field.ty;
370 let is_primitive = is_primitive_type(field_type);
371
372 if !is_primitive {
373 let comment = if !field_docs.is_empty() {
376 format!(" // {}", field_docs)
377 } else {
378 String::new()
379 };
380
381 field_schema_parts.push(quote! {
382 {
383 let type_name = stringify!(#field_type);
384 format!(" {}: {};{}", #field_name_str, type_name, #comment)
385 }
386 });
387
388 nested_type_collectors.push(quote! {
390 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
391 });
392 } else {
393 let type_str = format_type_for_schema(&field.ty);
396 let comment = if !field_docs.is_empty() {
397 format!(" // {}", field_docs)
398 } else {
399 String::new()
400 };
401
402 field_schema_parts.push(quote! {
403 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
404 });
405 }
406 }
407 }
408
409 let mut header_lines = Vec::new();
424
425 if !struct_docs.is_empty() {
427 header_lines.push("/**".to_string());
428 header_lines.push(format!(" * {}", struct_docs));
429 header_lines.push(" */".to_string());
430 }
431
432 header_lines.push(format!("type {} = {{", struct_name));
434
435 quote! {
436 {
437 let mut all_lines: Vec<String> = Vec::new();
438
439 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
441 let mut seen_types = std::collections::HashSet::<String>::new();
442
443 for schema in nested_schemas {
444 if !schema.is_empty() {
445 if seen_types.insert(schema.clone()) {
447 all_lines.push(schema);
448 all_lines.push(String::new()); }
450 }
451 }
452
453 let mut lines: Vec<String> = Vec::new();
455 #(lines.push(#header_lines.to_string());)*
456 #(lines.push(#field_schema_parts);)*
457 lines.push("}".to_string());
458 all_lines.push(lines.join("\n"));
459
460 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
461 }
462 }
463}
464
465fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
467 if let syn::Type::Path(type_path) = ty
468 && let Some(last_segment) = type_path.path.segments.last()
469 && last_segment.ident == "Vec"
470 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
471 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
472 {
473 return (true, Some(inner_type));
474 }
475 (false, None)
476}
477
478fn extract_option_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
480 if let syn::Type::Path(type_path) = ty
481 && let Some(last_segment) = type_path.path.segments.last()
482 && last_segment.ident == "Option"
483 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
484 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
485 {
486 return (true, Some(inner_type));
487 }
488 (false, None)
489}
490
491fn extract_map_value_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
493 if let syn::Type::Path(type_path) = ty
494 && let Some(last_segment) = type_path.path.segments.last()
495 && (last_segment.ident == "HashMap" || last_segment.ident == "BTreeMap")
496 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
497 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
498 {
499 return (true, Some(value_type));
500 }
501 (false, None)
502}
503
504fn extract_set_element_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
506 if let syn::Type::Path(type_path) = ty
507 && let Some(last_segment) = type_path.path.segments.last()
508 && (last_segment.ident == "HashSet" || last_segment.ident == "BTreeSet")
509 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
510 && let Some(syn::GenericArgument::Type(element_type)) = args.args.first()
511 {
512 return (true, Some(element_type));
513 }
514 (false, None)
515}
516
517fn extract_expandable_type(ty: &syn::Type) -> &syn::Type {
520 if let syn::Type::Path(type_path) = ty
521 && let Some(last_segment) = type_path.path.segments.last()
522 {
523 let type_name = last_segment.ident.to_string();
524
525 if type_name == "Option"
527 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
528 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
529 {
530 return extract_expandable_type(inner_type);
531 }
532
533 if type_name == "Vec"
535 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
536 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
537 {
538 return extract_expandable_type(inner_type);
539 }
540
541 if (type_name == "HashSet" || type_name == "BTreeSet")
543 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
544 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
545 {
546 return extract_expandable_type(inner_type);
547 }
548
549 if (type_name == "HashMap" || type_name == "BTreeMap")
552 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
553 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
554 {
555 return extract_expandable_type(value_type);
556 }
557 }
558
559 ty
561}
562
563fn is_primitive_type(ty: &syn::Type) -> bool {
565 if let syn::Type::Path(type_path) = ty
566 && let Some(last_segment) = type_path.path.segments.last()
567 {
568 let type_name = last_segment.ident.to_string();
569
570 if type_name == "Option"
572 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
573 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
574 {
575 return is_primitive_type(inner_type);
576 }
577
578 if type_name == "Vec"
580 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
581 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
582 {
583 return is_primitive_type(inner_type);
584 }
585
586 if (type_name == "HashSet" || type_name == "BTreeSet")
588 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
589 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
590 {
591 return is_primitive_type(inner_type);
592 }
593
594 if (type_name == "HashMap" || type_name == "BTreeMap")
597 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
598 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
599 {
600 return is_primitive_type(value_type);
601 }
602
603 matches!(
604 type_name.as_str(),
605 "String"
606 | "str"
607 | "i8"
608 | "i16"
609 | "i32"
610 | "i64"
611 | "i128"
612 | "isize"
613 | "u8"
614 | "u16"
615 | "u32"
616 | "u64"
617 | "u128"
618 | "usize"
619 | "f32"
620 | "f64"
621 | "bool"
622 )
623 } else {
624 true
626 }
627}
628
629fn format_type_for_schema(ty: &syn::Type) -> String {
631 match ty {
633 syn::Type::Path(type_path) => {
634 let path = &type_path.path;
635 if let Some(last_segment) = path.segments.last() {
636 let type_name = last_segment.ident.to_string();
637
638 if type_name == "Option"
640 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
641 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
642 {
643 return format!("{} | null", format_type_for_schema(inner_type));
644 }
645
646 match type_name.as_str() {
648 "String" | "str" => "string".to_string(),
649 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
650 | "u64" | "u128" | "usize" => "number".to_string(),
651 "f32" | "f64" => "number".to_string(),
652 "bool" => "boolean".to_string(),
653 "Vec" => {
654 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
655 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
656 {
657 return format!("{}[]", format_type_for_schema(inner_type));
658 }
659 "array".to_string()
660 }
661 _ => type_name,
663 }
664 } else {
665 "unknown".to_string()
666 }
667 }
668 _ => "unknown".to_string(),
669 }
670}
671
/// Options read from `#[prompt(...)]` attributes on an enum variant by
/// `parse_prompt_attributes`. The default has every option unset.
#[derive(Default)]
struct PromptAttributes {
    /// Text from `description = "..."` or from the `#[prompt("...")]` shorthand.
    description: Option<String>,
    /// Replacement name from `rename = "..."`.
    rename: Option<String>,
    /// Set by a bare `skip`.
    skip: bool,
}
679
680fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
683 let mut result = PromptAttributes::default();
684
685 for attr in attrs {
686 if attr.path().is_ident("prompt") {
687 if let Ok(meta_list) = attr.meta.require_list() {
689 if let Ok(metas) =
691 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
692 {
693 for meta in metas {
694 if let Meta::NameValue(nv) = meta {
695 if nv.path.is_ident("rename") {
696 if let syn::Expr::Lit(syn::ExprLit {
697 lit: syn::Lit::Str(lit_str),
698 ..
699 }) = nv.value
700 {
701 result.rename = Some(lit_str.value());
702 }
703 } else if nv.path.is_ident("description")
704 && let syn::Expr::Lit(syn::ExprLit {
705 lit: syn::Lit::Str(lit_str),
706 ..
707 }) = nv.value
708 {
709 result.description = Some(lit_str.value());
710 }
711 } else if let Meta::Path(path) = meta
712 && path.is_ident("skip")
713 {
714 result.skip = true;
715 }
716 }
717 }
718
719 let tokens_str = meta_list.tokens.to_string();
721 if tokens_str == "skip" {
722 result.skip = true;
723 }
724 }
725
726 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
728 result.description = Some(lit_str.value());
729 }
730 }
731 }
732 result
733}
734
/// Produces a placeholder JSON value for a schema type name such as those returned by
/// `format_type_for_schema`.
///
/// * `string` -> `"example"`, `number` -> `0`, `boolean` -> `false`
/// * array types (`T[]`, `(A | B)[]`) -> `[]`
/// * record types (`Record<string, V>`) -> `{}`
/// * unions (`A | B`) -> the example for the first member
/// * anything else -> a `"<See Name>"` pointer to that type's own schema
fn generate_example_value_for_type(type_str: &str) -> String {
    match type_str {
        "string" => "\"example\"".to_string(),
        "number" => "0".to_string(),
        "boolean" => "false".to_string(),
        s if s.ends_with("[]") => "[]".to_string(),
        // Checked before unions so `Record<string, number | null>` stays a record.
        s if s.starts_with("Record<") && s.ends_with('>') => "{}".to_string(),
        s if s.contains('|') => {
            // `split` always yields at least one piece; fall back to `s` defensively.
            let first_type = s.split('|').next().unwrap_or(s).trim();
            generate_example_value_for_type(first_type)
        }
        custom_type => format!("\"<See {}>\"", custom_type),
    }
}
755
756fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
758 for attr in attrs {
759 if attr.path().is_ident("serde")
760 && let Ok(meta_list) = attr.meta.require_list()
761 && let Ok(metas) =
762 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
763 {
764 for meta in metas {
765 if let Meta::NameValue(nv) = meta
766 && nv.path.is_ident("rename")
767 && let syn::Expr::Lit(syn::ExprLit {
768 lit: syn::Lit::Str(lit_str),
769 ..
770 }) = nv.value
771 {
772 return Some(lit_str.value());
773 }
774 }
775 }
776 }
777 None
778}
779
/// Case-conversion rule from `#[serde(rename_all = "...")]`, applied to enum variant
/// names so that the names shown in a prompt follow serde's documented `rename_all`
/// conventions for variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leave names unchanged. `from_str` never produces this variant, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    None,
    /// `"lowercase"`
    LowerCase,
    /// `"UPPERCASE"`
    UpperCase,
    /// `"PascalCase"` (variant names are already PascalCase, so this is the identity)
    PascalCase,
    /// `"camelCase"`
    CamelCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses a `rename_all` value; returns `None` for unrecognised rule names.
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies the rule to a PascalCase variant name.
    ///
    /// Case changes are ASCII-only, like serde's. Multi-character Unicode case
    /// mappings are therefore not truncated, and non-ASCII letters are kept
    /// as-is. The kebab rules are derived from the snake rules by replacing `_`
    /// with `-`, so a name containing `_` converts consistently
    /// (`Foo_Bar` -> `foo--bar`).
    fn apply(&self, name: &str) -> String {
        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_ascii_lowercase(),
            Self::UpperCase => name.to_ascii_uppercase(),
            Self::CamelCase => {
                let mut chars = name.chars();
                match chars.next() {
                    Some(first) => first.to_ascii_lowercase().to_string() + chars.as_str(),
                    None => String::new(),
                }
            }
            Self::SnakeCase => Self::snake_case_of(name),
            Self::ScreamingSnakeCase => Self::snake_case_of(name).to_ascii_uppercase(),
            Self::KebabCase => Self::snake_case_of(name).replace('_', "-"),
            Self::ScreamingKebabCase => Self::snake_case_of(name)
                .to_ascii_uppercase()
                .replace('_', "-"),
        }
    }

    /// Inserts `_` before every uppercase character except the first and
    /// ASCII-lowercases the result.
    fn snake_case_of(name: &str) -> String {
        let mut snake = String::with_capacity(name.len() + 4);
        for (i, ch) in name.char_indices() {
            if i > 0 && ch.is_uppercase() {
                snake.push('_');
            }
            snake.push(ch.to_ascii_lowercase());
        }
        snake
    }
}
873
874fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
876 for attr in attrs {
877 if attr.path().is_ident("serde")
878 && let Ok(meta_list) = attr.meta.require_list()
879 {
880 if let Ok(metas) =
882 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
883 {
884 for meta in metas {
885 if let Meta::NameValue(nv) = meta
886 && nv.path.is_ident("rename_all")
887 && let syn::Expr::Lit(syn::ExprLit {
888 lit: syn::Lit::Str(lit_str),
889 ..
890 }) = nv.value
891 {
892 return RenameRule::from_str(&lit_str.value());
893 }
894 }
895 }
896 }
897 }
898 None
899}
900
901fn parse_serde_tag(attrs: &[syn::Attribute]) -> Option<String> {
904 for attr in attrs {
905 if attr.path().is_ident("serde")
906 && let Ok(meta_list) = attr.meta.require_list()
907 {
908 if let Ok(metas) =
910 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
911 {
912 for meta in metas {
913 if let Meta::NameValue(nv) = meta
914 && nv.path.is_ident("tag")
915 && let syn::Expr::Lit(syn::ExprLit {
916 lit: syn::Lit::Str(lit_str),
917 ..
918 }) = nv.value
919 {
920 return Some(lit_str.value());
921 }
922 }
923 }
924 }
925 }
926 None
927}
928
929fn parse_serde_untagged(attrs: &[syn::Attribute]) -> bool {
932 for attr in attrs {
933 if attr.path().is_ident("serde")
934 && let Ok(meta_list) = attr.meta.require_list()
935 {
936 if let Ok(metas) =
938 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
939 {
940 for meta in metas {
941 if let Meta::Path(path) = meta
942 && path.is_ident("untagged")
943 {
944 return true;
945 }
946 }
947 }
948 }
949 }
950 false
951}
952
/// Field-level options parsed from `#[prompt(...)]` attributes by
/// `parse_field_prompt_attrs`. The default has every option unset.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by `skip`; the schema and example generators omit such fields.
    skip: bool,
    /// Value of `rename = "..."`. Stored only; the schema/example generators in this
    /// file key fields by their Rust name. NOTE(review): confirm where this is applied.
    rename: Option<String>,
    /// Value of `format_with = "..."`, stored verbatim; its use is not visible here.
    format_with: Option<String>,
    /// Set by `image`; its effect is handled outside the code shown here.
    image: bool,
    /// Value of `example = "..."`; inserted as a JSON string in example output.
    example: Option<String>,
}
962
963fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
965 let mut result = FieldPromptAttrs::default();
966
967 for attr in attrs {
968 if attr.path().is_ident("prompt") {
969 if let Ok(meta_list) = attr.meta.require_list() {
971 if let Ok(metas) =
973 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
974 {
975 for meta in metas {
976 match meta {
977 Meta::Path(path) if path.is_ident("skip") => {
978 result.skip = true;
979 }
980 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
981 if let syn::Expr::Lit(syn::ExprLit {
982 lit: syn::Lit::Str(lit_str),
983 ..
984 }) = nv.value
985 {
986 result.rename = Some(lit_str.value());
987 }
988 }
989 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
990 if let syn::Expr::Lit(syn::ExprLit {
991 lit: syn::Lit::Str(lit_str),
992 ..
993 }) = nv.value
994 {
995 result.format_with = Some(lit_str.value());
996 }
997 }
998 Meta::Path(path) if path.is_ident("image") => {
999 result.image = true;
1000 }
1001 Meta::NameValue(nv) if nv.path.is_ident("example") => {
1002 if let syn::Expr::Lit(syn::ExprLit {
1003 lit: syn::Lit::Str(lit_str),
1004 ..
1005 }) = nv.value
1006 {
1007 result.example = Some(lit_str.value());
1008 }
1009 }
1010 _ => {}
1011 }
1012 }
1013 } else if meta_list.tokens.to_string() == "skip" {
1014 result.skip = true;
1016 } else if meta_list.tokens.to_string() == "image" {
1017 result.image = true;
1019 }
1020 }
1021 }
1022 }
1023
1024 result
1025}
1026
1027#[proc_macro_derive(ToPrompt, attributes(prompt))]
1070pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
1071 let input = parse_macro_input!(input as DeriveInput);
1072
1073 let found_crate =
1074 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1075 let crate_path = match found_crate {
1076 FoundCrate::Itself => {
1077 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1079 quote!(::#ident)
1080 }
1081 FoundCrate::Name(name) => {
1082 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1083 quote!(::#ident)
1084 }
1085 };
1086
1087 match &input.data {
1089 Data::Enum(data_enum) => {
1090 let enum_name = &input.ident;
1092 let enum_docs = extract_doc_comments(&input.attrs);
1093
1094 let serde_tag = parse_serde_tag(&input.attrs);
1096 let is_internally_tagged = serde_tag.is_some();
1097 let is_untagged = parse_serde_untagged(&input.attrs);
1098
1099 let rename_rule = parse_serde_rename_all(&input.attrs);
1101
1102 let mut variant_lines = Vec::new();
1115 let mut first_variant_name = None;
1116
1117 let mut example_unit: Option<String> = None;
1119 let mut example_struct: Option<String> = None;
1120 let mut example_tuple: Option<String> = None;
1121
1122 let mut nested_types: Vec<&syn::Type> = Vec::new();
1124
1125 for variant in &data_enum.variants {
1126 let variant_name = &variant.ident;
1127 let variant_name_str = variant_name.to_string();
1128
1129 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1131
1132 if prompt_attrs.skip {
1134 continue;
1135 }
1136
1137 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1143 prompt_rename.clone()
1144 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1145 serde_rename
1146 } else if let Some(rule) = rename_rule {
1147 rule.apply(&variant_name_str)
1148 } else {
1149 variant_name_str.clone()
1150 };
1151
1152 let variant_line = match &variant.fields {
1154 syn::Fields::Unit => {
1155 if example_unit.is_none() {
1157 example_unit = Some(format!("\"{}\"", variant_value));
1158 }
1159
1160 if let Some(desc) = &prompt_attrs.description {
1162 format!(" | \"{}\" // {}", variant_value, desc)
1163 } else {
1164 let docs = extract_doc_comments(&variant.attrs);
1165 if !docs.is_empty() {
1166 format!(" | \"{}\" // {}", variant_value, docs)
1167 } else {
1168 format!(" | \"{}\"", variant_value)
1169 }
1170 }
1171 }
1172 syn::Fields::Named(fields) => {
1173 let mut field_parts = Vec::new();
1174 let mut example_field_parts = Vec::new();
1175
1176 if is_internally_tagged && let Some(tag_name) = &serde_tag {
1178 field_parts.push(format!("{}: \"{}\"", tag_name, variant_value));
1179 example_field_parts
1180 .push(format!("{}: \"{}\"", tag_name, variant_value));
1181 }
1182
1183 for field in &fields.named {
1184 let field_name = field.ident.as_ref().unwrap().to_string();
1185 let field_type = format_type_for_schema(&field.ty);
1186 field_parts.push(format!("{}: {}", field_name, field_type.clone()));
1187
1188 let expandable_type = extract_expandable_type(&field.ty);
1191 if !is_primitive_type(expandable_type) {
1192 nested_types.push(expandable_type);
1193 }
1194
1195 let example_value = generate_example_value_for_type(&field_type);
1197 example_field_parts.push(format!("{}: {}", field_name, example_value));
1198 }
1199
1200 let field_str = field_parts.join(", ");
1201 let example_field_str = example_field_parts.join(", ");
1202
1203 if example_struct.is_none() {
1205 if is_untagged || is_internally_tagged {
1206 example_struct = Some(format!("{{ {} }}", example_field_str));
1207 } else {
1208 example_struct = Some(format!(
1209 "{{ \"{}\": {{ {} }} }}",
1210 variant_value, example_field_str
1211 ));
1212 }
1213 }
1214
1215 let comment = if let Some(desc) = &prompt_attrs.description {
1216 format!(" // {}", desc)
1217 } else {
1218 let docs = extract_doc_comments(&variant.attrs);
1219 if !docs.is_empty() {
1220 format!(" // {}", docs)
1221 } else if is_untagged {
1222 format!(" // {}", variant_value)
1224 } else {
1225 String::new()
1226 }
1227 };
1228
1229 if is_untagged {
1230 format!(" | {{ {} }}{}", field_str, comment)
1232 } else if is_internally_tagged {
1233 format!(" | {{ {} }}{}", field_str, comment)
1235 } else {
1236 format!(
1238 " | {{ \"{}\": {{ {} }} }}{}",
1239 variant_value, field_str, comment
1240 )
1241 }
1242 }
1243 syn::Fields::Unnamed(fields) => {
1244 let field_types: Vec<String> = fields
1245 .unnamed
1246 .iter()
1247 .map(|f| {
1248 let expandable_type = extract_expandable_type(&f.ty);
1251 if !is_primitive_type(expandable_type) {
1252 nested_types.push(expandable_type);
1253 }
1254 format_type_for_schema(&f.ty)
1255 })
1256 .collect();
1257
1258 let tuple_str = field_types.join(", ");
1259
1260 let example_values: Vec<String> = field_types
1262 .iter()
1263 .map(|type_str| generate_example_value_for_type(type_str))
1264 .collect();
1265 let example_tuple_str = example_values.join(", ");
1266
1267 if example_tuple.is_none() {
1269 if is_untagged || is_internally_tagged {
1270 example_tuple = Some(format!("[{}]", example_tuple_str));
1271 } else {
1272 example_tuple = Some(format!(
1273 "{{ \"{}\": [{}] }}",
1274 variant_value, example_tuple_str
1275 ));
1276 }
1277 }
1278
1279 let comment = if let Some(desc) = &prompt_attrs.description {
1280 format!(" // {}", desc)
1281 } else {
1282 let docs = extract_doc_comments(&variant.attrs);
1283 if !docs.is_empty() {
1284 format!(" // {}", docs)
1285 } else if is_untagged {
1286 format!(" // {}", variant_value)
1288 } else {
1289 String::new()
1290 }
1291 };
1292
1293 if is_untagged || is_internally_tagged {
1294 format!(" | [{}]{}", tuple_str, comment)
1297 } else {
1298 format!(
1300 " | {{ \"{}\": [{}] }}{}",
1301 variant_value, tuple_str, comment
1302 )
1303 }
1304 }
1305 };
1306
1307 variant_lines.push(variant_line);
1308
1309 if first_variant_name.is_none() {
1310 first_variant_name = Some(variant_value);
1311 }
1312 }
1313
1314 let mut lines = Vec::new();
1316
1317 if !enum_docs.is_empty() {
1319 lines.push("/**".to_string());
1320 lines.push(format!(" * {}", enum_docs));
1321 lines.push(" */".to_string());
1322 }
1323
1324 lines.push(format!("type {} =", enum_name));
1326
1327 for line in &variant_lines {
1329 lines.push(line.clone());
1330 }
1331
1332 if let Some(last) = lines.last_mut()
1334 && !last.ends_with(';')
1335 {
1336 last.push(';');
1337 }
1338
1339 let mut examples = Vec::new();
1341 if let Some(ex) = example_unit {
1342 examples.push(ex);
1343 }
1344 if let Some(ex) = example_struct {
1345 examples.push(ex);
1346 }
1347 if let Some(ex) = example_tuple {
1348 examples.push(ex);
1349 }
1350
1351 if !examples.is_empty() {
1352 lines.push("".to_string()); if examples.len() == 1 {
1354 lines.push(format!("Example value: {}", examples[0]));
1355 } else {
1356 lines.push("Example values:".to_string());
1357 for ex in examples {
1358 lines.push(format!(" {}", ex));
1359 }
1360 }
1361 }
1362
1363 let nested_type_tokens: Vec<_> = nested_types
1365 .iter()
1366 .map(|field_ty| {
1367 quote! {
1368 {
1369 let type_schema = <#field_ty as #crate_path::prompt::ToPrompt>::prompt_schema();
1370 if !type_schema.is_empty() {
1371 format!("\n\n{}", type_schema)
1372 } else {
1373 String::new()
1374 }
1375 }
1376 }
1377 })
1378 .collect();
1379
1380 let prompt_string = if nested_type_tokens.is_empty() {
1381 let lines_str = lines.join("\n");
1382 quote! { #lines_str.to_string() }
1383 } else {
1384 let lines_str = lines.join("\n");
1385 quote! {
1386 {
1387 let mut result = String::from(#lines_str);
1388
1389 let nested_schemas: Vec<String> = vec![#(#nested_type_tokens),*];
1391 let mut seen_schemas = std::collections::HashSet::<String>::new();
1392
1393 for schema in nested_schemas {
1394 if !schema.is_empty() && seen_schemas.insert(schema.clone()) {
1395 result.push_str(&schema);
1396 }
1397 }
1398
1399 result
1400 }
1401 }
1402 };
1403 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1404
1405 let mut match_arms = Vec::new();
1407 for variant in &data_enum.variants {
1408 let variant_name = &variant.ident;
1409 let variant_name_str = variant_name.to_string();
1410
1411 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1413
1414 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1420 prompt_rename.clone()
1421 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1422 serde_rename
1423 } else if let Some(rule) = rename_rule {
1424 rule.apply(&variant_name_str)
1425 } else {
1426 variant_name_str.clone()
1427 };
1428
1429 match &variant.fields {
1431 syn::Fields::Unit => {
1432 if prompt_attrs.skip {
1434 match_arms.push(quote! {
1435 Self::#variant_name => stringify!(#variant_name).to_string()
1436 });
1437 } else if let Some(desc) = &prompt_attrs.description {
1438 match_arms.push(quote! {
1439 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
1440 });
1441 } else {
1442 let variant_docs = extract_doc_comments(&variant.attrs);
1443 if !variant_docs.is_empty() {
1444 match_arms.push(quote! {
1445 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
1446 });
1447 } else {
1448 match_arms.push(quote! {
1449 Self::#variant_name => #variant_value.to_string()
1450 });
1451 }
1452 }
1453 }
1454 syn::Fields::Named(fields) => {
1455 let field_bindings: Vec<_> = fields
1457 .named
1458 .iter()
1459 .map(|f| f.ident.as_ref().unwrap())
1460 .collect();
1461
1462 let field_displays: Vec<_> = fields
1463 .named
1464 .iter()
1465 .map(|f| {
1466 let field_name = f.ident.as_ref().unwrap();
1467 let field_name_str = field_name.to_string();
1468 quote! {
1469 format!("{}: {:?}", #field_name_str, #field_name)
1470 }
1471 })
1472 .collect();
1473
1474 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1475 desc.clone()
1476 } else {
1477 let docs = extract_doc_comments(&variant.attrs);
1478 if !docs.is_empty() {
1479 docs
1480 } else {
1481 String::new()
1482 }
1483 };
1484
1485 if doc_or_desc.is_empty() {
1486 match_arms.push(quote! {
1487 Self::#variant_name { #(#field_bindings),* } => {
1488 let fields = vec![#(#field_displays),*];
1489 format!("{} {{ {} }}", #variant_value, fields.join(", "))
1490 }
1491 });
1492 } else {
1493 match_arms.push(quote! {
1494 Self::#variant_name { #(#field_bindings),* } => {
1495 let fields = vec![#(#field_displays),*];
1496 format!("{}: {} {{ {} }}", #variant_value, #doc_or_desc, fields.join(", "))
1497 }
1498 });
1499 }
1500 }
1501 syn::Fields::Unnamed(fields) => {
1502 let field_count = fields.unnamed.len();
1504 let field_bindings: Vec<_> = (0..field_count)
1505 .map(|i| {
1506 syn::Ident::new(
1507 &format!("field{}", i),
1508 proc_macro2::Span::call_site(),
1509 )
1510 })
1511 .collect();
1512
1513 let field_displays: Vec<_> = field_bindings
1514 .iter()
1515 .map(|field_name| {
1516 quote! {
1517 format!("{:?}", #field_name)
1518 }
1519 })
1520 .collect();
1521
1522 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1523 desc.clone()
1524 } else {
1525 let docs = extract_doc_comments(&variant.attrs);
1526 if !docs.is_empty() {
1527 docs
1528 } else {
1529 String::new()
1530 }
1531 };
1532
1533 if doc_or_desc.is_empty() {
1534 match_arms.push(quote! {
1535 Self::#variant_name(#(#field_bindings),*) => {
1536 let fields = vec![#(#field_displays),*];
1537 format!("{}({})", #variant_value, fields.join(", "))
1538 }
1539 });
1540 } else {
1541 match_arms.push(quote! {
1542 Self::#variant_name(#(#field_bindings),*) => {
1543 let fields = vec![#(#field_displays),*];
1544 format!("{}: {}({})", #variant_value, #doc_or_desc, fields.join(", "))
1545 }
1546 });
1547 }
1548 }
1549 }
1550 }
1551
1552 let to_prompt_impl = if match_arms.is_empty() {
1553 quote! {
1555 fn to_prompt(&self) -> String {
1556 match *self {}
1557 }
1558 }
1559 } else {
1560 quote! {
1561 fn to_prompt(&self) -> String {
1562 match self {
1563 #(#match_arms),*
1564 }
1565 }
1566 }
1567 };
1568
1569 let expanded = quote! {
1570 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
1571 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1572 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
1573 }
1574
1575 #to_prompt_impl
1576
1577 fn prompt_schema() -> String {
1578 #prompt_string
1579 }
1580 }
1581 };
1582
1583 TokenStream::from(expanded)
1584 }
1585 Data::Struct(data_struct) => {
1586 let mut template_attr = None;
1588 let mut template_file_attr = None;
1589 let mut mode_attr = None;
1590 let mut validate_attr = false;
1591 let mut type_marker_attr = false;
1592
1593 for attr in &input.attrs {
1594 if attr.path().is_ident("prompt") {
1595 if let Ok(metas) =
1597 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1598 {
1599 for meta in metas {
1600 match meta {
1601 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1602 if let syn::Expr::Lit(expr_lit) = nv.value
1603 && let syn::Lit::Str(lit_str) = expr_lit.lit
1604 {
1605 template_attr = Some(lit_str.value());
1606 }
1607 }
1608 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
1609 if let syn::Expr::Lit(expr_lit) = nv.value
1610 && let syn::Lit::Str(lit_str) = expr_lit.lit
1611 {
1612 template_file_attr = Some(lit_str.value());
1613 }
1614 }
1615 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1616 if let syn::Expr::Lit(expr_lit) = nv.value
1617 && let syn::Lit::Str(lit_str) = expr_lit.lit
1618 {
1619 mode_attr = Some(lit_str.value());
1620 }
1621 }
1622 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
1623 if let syn::Expr::Lit(expr_lit) = nv.value
1624 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1625 {
1626 validate_attr = lit_bool.value();
1627 }
1628 }
1629 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
1630 if let syn::Expr::Lit(expr_lit) = nv.value
1631 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1632 {
1633 type_marker_attr = lit_bool.value();
1634 }
1635 }
1636 Meta::Path(path) if path.is_ident("type_marker") => {
1637 type_marker_attr = true;
1639 }
1640 _ => {}
1641 }
1642 }
1643 }
1644 }
1645 }
1646
1647 if template_attr.is_some() && template_file_attr.is_some() {
1649 return syn::Error::new(
1650 input.ident.span(),
1651 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
1652 ).to_compile_error().into();
1653 }
1654
1655 let template_str = if let Some(file_path) = template_file_attr {
1657 let mut full_path = None;
1661
1662 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1664 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
1666
1667 if !is_trybuild {
1668 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1670 if candidate.exists() {
1671 full_path = Some(candidate);
1672 }
1673 } else {
1674 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1680 let workspace_root = &manifest_dir[..target_pos];
1681 let original_macros_dir = std::path::Path::new(workspace_root)
1683 .join("crates")
1684 .join("llm-toolkit-macros");
1685
1686 let candidate = original_macros_dir.join(&file_path);
1687 if candidate.exists() {
1688 full_path = Some(candidate);
1689 }
1690 }
1691 }
1692 }
1693
1694 if full_path.is_none() {
1696 let candidate = std::path::Path::new(&file_path).to_path_buf();
1697 if candidate.exists() {
1698 full_path = Some(candidate);
1699 }
1700 }
1701
1702 if full_path.is_none()
1705 && let Ok(current_dir) = std::env::current_dir()
1706 {
1707 let mut search_dir = current_dir.as_path();
1708 for _ in 0..10 {
1710 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1712 if macros_dir.exists() {
1713 let candidate = macros_dir.join(&file_path);
1714 if candidate.exists() {
1715 full_path = Some(candidate);
1716 break;
1717 }
1718 }
1719 let candidate = search_dir.join(&file_path);
1721 if candidate.exists() {
1722 full_path = Some(candidate);
1723 break;
1724 }
1725 if let Some(parent) = search_dir.parent() {
1726 search_dir = parent;
1727 } else {
1728 break;
1729 }
1730 }
1731 }
1732
1733 if full_path.is_none() {
1735 let mut error_msg = format!(
1737 "Template file '{}' not found at compile time.\n\nSearched in:",
1738 file_path
1739 );
1740
1741 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1742 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1743 error_msg.push_str(&format!("\n - {}", candidate.display()));
1744 }
1745
1746 if let Ok(current_dir) = std::env::current_dir() {
1747 let candidate = current_dir.join(&file_path);
1748 error_msg.push_str(&format!("\n - {}", candidate.display()));
1749 }
1750
1751 error_msg.push_str("\n\nPlease ensure:");
1752 error_msg.push_str("\n 1. The template file exists");
1753 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1754 error_msg.push_str("\n 3. There are no typos in the path");
1755
1756 return syn::Error::new(input.ident.span(), error_msg)
1757 .to_compile_error()
1758 .into();
1759 }
1760
1761 let final_path = full_path.unwrap();
1762
1763 match std::fs::read_to_string(&final_path) {
1765 Ok(content) => Some(content),
1766 Err(e) => {
1767 return syn::Error::new(
1768 input.ident.span(),
1769 format!(
1770 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1771 file_path,
1772 e,
1773 final_path.display()
1774 ),
1775 )
1776 .to_compile_error()
1777 .into();
1778 }
1779 }
1780 } else {
1781 template_attr
1782 };
1783
1784 if validate_attr && let Some(template) = &template_str {
1786 let mut env = minijinja::Environment::new();
1788 if let Err(e) = env.add_template("validation", template) {
1789 let warning_msg =
1791 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1792 let warning_ident = syn::Ident::new(
1793 "TEMPLATE_VALIDATION_WARNING",
1794 proc_macro2::Span::call_site(),
1795 );
1796 let _warning_tokens = quote! {
1797 #[deprecated(note = #warning_msg)]
1798 const #warning_ident: () = ();
1799 let _ = #warning_ident;
1800 };
1801 eprintln!("cargo:warning={}", warning_msg);
1803 }
1804
1805 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1807 &fields.named
1808 } else {
1809 panic!("Template validation is only supported for structs with named fields.");
1810 };
1811
1812 let field_names: std::collections::HashSet<String> = fields
1813 .iter()
1814 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1815 .collect();
1816
1817 let placeholders = parse_template_placeholders_with_mode(template);
1819
1820 for (placeholder_name, _mode) in &placeholders {
1821 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1822 let warning_msg = format!(
1823 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1824 placeholder_name
1825 );
1826 eprintln!("cargo:warning={}", warning_msg);
1827 }
1828 }
1829 }
1830
1831 let name = input.ident;
1832 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1833
1834 let struct_docs = extract_doc_comments(&input.attrs);
1836
1837 let is_mode_based =
1839 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1840
1841 let expanded = if is_mode_based || mode_attr.is_some() {
1842 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1844 &fields.named
1845 } else {
1846 panic!(
1847 "Mode-based prompt generation is only supported for structs with named fields."
1848 );
1849 };
1850
1851 let struct_name_str = name.to_string();
1852
1853 let has_default = input.attrs.iter().any(|attr| {
1855 if attr.path().is_ident("derive")
1856 && let Ok(meta_list) = attr.meta.require_list()
1857 {
1858 let tokens_str = meta_list.tokens.to_string();
1859 tokens_str.contains("Default")
1860 } else {
1861 false
1862 }
1863 });
1864
1865 let schema_parts = generate_schema_only_parts(
1876 &struct_name_str,
1877 &struct_docs,
1878 fields,
1879 &crate_path,
1880 type_marker_attr,
1881 );
1882
1883 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1885
1886 quote! {
1887 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1888 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1889 match mode {
1890 "schema_only" => #schema_parts,
1891 "example_only" => #example_parts,
1892 "full" | _ => {
1893 let mut parts = Vec::new();
1895
1896 let schema_parts = #schema_parts;
1898 parts.extend(schema_parts);
1899
1900 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1902 parts.push(#crate_path::prompt::PromptPart::Text(
1903 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1904 ));
1905
1906 let example_parts = #example_parts;
1908 parts.extend(example_parts);
1909
1910 parts
1911 }
1912 }
1913 }
1914
1915 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1916 self.to_prompt_parts_with_mode("full")
1917 }
1918
1919 fn to_prompt(&self) -> String {
1920 self.to_prompt_parts()
1921 .into_iter()
1922 .filter_map(|part| match part {
1923 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1924 _ => None,
1925 })
1926 .collect::<Vec<_>>()
1927 .join("\n")
1928 }
1929
1930 fn prompt_schema() -> String {
1931 use std::sync::OnceLock;
1932 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1933
1934 SCHEMA_CACHE.get_or_init(|| {
1935 let schema_parts = #schema_parts;
1936 schema_parts
1937 .into_iter()
1938 .filter_map(|part| match part {
1939 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1940 _ => None,
1941 })
1942 .collect::<Vec<_>>()
1943 .join("\n")
1944 }).clone()
1945 }
1946 }
1947 }
1948 } else if let Some(template) = template_str {
1949 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1952 &fields.named
1953 } else {
1954 panic!(
1955 "Template prompt generation is only supported for structs with named fields."
1956 );
1957 };
1958
1959 let placeholders = parse_template_placeholders_with_mode(&template);
1961 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1963 mode.is_some()
1964 && fields
1965 .iter()
1966 .any(|f| f.ident.as_ref().unwrap() == field_name)
1967 });
1968
1969 let mut image_field_parts = Vec::new();
1970 for f in fields.iter() {
1971 let field_name = f.ident.as_ref().unwrap();
1972 let attrs = parse_field_prompt_attrs(&f.attrs);
1973
1974 if attrs.image {
1975 image_field_parts.push(quote! {
1977 parts.extend(self.#field_name.to_prompt_parts());
1978 });
1979 }
1980 }
1981
1982 if has_mode_syntax {
1984 let mut context_fields = Vec::new();
1986 let mut modified_template = template.clone();
1987
1988 for (field_name, mode_opt) in &placeholders {
1990 if let Some(mode) = mode_opt {
1991 let unique_key = format!("{}__{}", field_name, mode);
1993
1994 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1996 let replacement = format!("{{{{ {} }}}}", unique_key);
1997 modified_template = modified_template.replace(&pattern, &replacement);
1998
1999 let field_ident =
2001 syn::Ident::new(field_name, proc_macro2::Span::call_site());
2002
2003 context_fields.push(quote! {
2005 context.insert(
2006 #unique_key.to_string(),
2007 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
2008 );
2009 });
2010 }
2011 }
2012
2013 for field in fields.iter() {
2015 let field_name = field.ident.as_ref().unwrap();
2016 let field_name_str = field_name.to_string();
2017
2018 let has_mode_entry = placeholders
2020 .iter()
2021 .any(|(name, mode)| name == &field_name_str && mode.is_some());
2022
2023 if !has_mode_entry {
2024 let is_primitive = match &field.ty {
2027 syn::Type::Path(type_path) => {
2028 if let Some(segment) = type_path.path.segments.last() {
2029 let type_name = segment.ident.to_string();
2030 matches!(
2031 type_name.as_str(),
2032 "String"
2033 | "str"
2034 | "i8"
2035 | "i16"
2036 | "i32"
2037 | "i64"
2038 | "i128"
2039 | "isize"
2040 | "u8"
2041 | "u16"
2042 | "u32"
2043 | "u64"
2044 | "u128"
2045 | "usize"
2046 | "f32"
2047 | "f64"
2048 | "bool"
2049 | "char"
2050 )
2051 } else {
2052 false
2053 }
2054 }
2055 _ => false,
2056 };
2057
2058 if is_primitive {
2059 context_fields.push(quote! {
2060 context.insert(
2061 #field_name_str.to_string(),
2062 minijinja::Value::from_serialize(&self.#field_name)
2063 );
2064 });
2065 } else {
2066 context_fields.push(quote! {
2068 context.insert(
2069 #field_name_str.to_string(),
2070 minijinja::Value::from(self.#field_name.to_prompt())
2071 );
2072 });
2073 }
2074 }
2075 }
2076
2077 quote! {
2078 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2079 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2080 let mut parts = Vec::new();
2081
2082 #(#image_field_parts)*
2084
2085 let text = {
2087 let mut env = minijinja::Environment::new();
2088 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
2089 panic!("Failed to parse template: {}", e)
2090 });
2091
2092 let tmpl = env.get_template("prompt").unwrap();
2093
2094 let mut context = std::collections::HashMap::new();
2095 #(#context_fields)*
2096
2097 tmpl.render(context).unwrap_or_else(|e| {
2098 format!("Failed to render prompt: {}", e)
2099 })
2100 };
2101
2102 if !text.is_empty() {
2103 parts.push(#crate_path::prompt::PromptPart::Text(text));
2104 }
2105
2106 parts
2107 }
2108
2109 fn to_prompt(&self) -> String {
2110 let mut env = minijinja::Environment::new();
2112 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
2113 panic!("Failed to parse template: {}", e)
2114 });
2115
2116 let tmpl = env.get_template("prompt").unwrap();
2117
2118 let mut context = std::collections::HashMap::new();
2119 #(#context_fields)*
2120
2121 tmpl.render(context).unwrap_or_else(|e| {
2122 format!("Failed to render prompt: {}", e)
2123 })
2124 }
2125
2126 fn prompt_schema() -> String {
2127 String::new() }
2129 }
2130 }
2131 } else {
2132 quote! {
2134 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2135 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2136 let mut parts = Vec::new();
2137
2138 #(#image_field_parts)*
2140
2141 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
2143 format!("Failed to render prompt: {}", e)
2144 });
2145 if !text.is_empty() {
2146 parts.push(#crate_path::prompt::PromptPart::Text(text));
2147 }
2148
2149 parts
2150 }
2151
2152 fn to_prompt(&self) -> String {
2153 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
2154 format!("Failed to render prompt: {}", e)
2155 })
2156 }
2157
2158 fn prompt_schema() -> String {
2159 String::new() }
2161 }
2162 }
2163 }
2164 } else {
2165 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
2168 &fields.named
2169 } else {
2170 panic!(
2171 "Default prompt generation is only supported for structs with named fields."
2172 );
2173 };
2174
2175 let mut text_field_parts = Vec::new();
2177 let mut image_field_parts = Vec::new();
2178
2179 for f in fields.iter() {
2180 let field_name = f.ident.as_ref().unwrap();
2181 let attrs = parse_field_prompt_attrs(&f.attrs);
2182
2183 if attrs.skip {
2185 continue;
2186 }
2187
2188 if attrs.image {
2189 image_field_parts.push(quote! {
2191 parts.extend(self.#field_name.to_prompt_parts());
2192 });
2193 } else {
2194 let key = if let Some(rename) = attrs.rename {
2200 rename
2201 } else {
2202 let doc_comment = extract_doc_comments(&f.attrs);
2203 if !doc_comment.is_empty() {
2204 doc_comment
2205 } else {
2206 field_name.to_string()
2207 }
2208 };
2209
2210 let value_expr = if let Some(format_with) = attrs.format_with {
2212 let func_path: syn::Path =
2214 syn::parse_str(&format_with).unwrap_or_else(|_| {
2215 panic!("Invalid function path: {}", format_with)
2216 });
2217 quote! { #func_path(&self.#field_name) }
2218 } else {
2219 quote! { self.#field_name.to_prompt() }
2220 };
2221
2222 text_field_parts.push(quote! {
2223 text_parts.push(format!("{}: {}", #key, #value_expr));
2224 });
2225 }
2226 }
2227
2228 let struct_name_str = name.to_string();
2230 let schema_parts = generate_schema_only_parts(
2231 &struct_name_str,
2232 &struct_docs,
2233 fields,
2234 &crate_path,
2235 false, );
2237
2238 quote! {
2240 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2241 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2242 let mut parts = Vec::new();
2243
2244 #(#image_field_parts)*
2246
2247 let mut text_parts = Vec::new();
2249 #(#text_field_parts)*
2250
2251 if !text_parts.is_empty() {
2252 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2253 }
2254
2255 parts
2256 }
2257
2258 fn to_prompt(&self) -> String {
2259 let mut text_parts = Vec::new();
2260 #(#text_field_parts)*
2261 text_parts.join("\n")
2262 }
2263
2264 fn prompt_schema() -> String {
2265 use std::sync::OnceLock;
2266 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
2267
2268 SCHEMA_CACHE.get_or_init(|| {
2269 let schema_parts = #schema_parts;
2270 schema_parts
2271 .into_iter()
2272 .filter_map(|part| match part {
2273 #crate_path::prompt::PromptPart::Text(text) => Some(text),
2274 _ => None,
2275 })
2276 .collect::<Vec<_>>()
2277 .join("\n")
2278 }).clone()
2279 }
2280 }
2281 }
2282 };
2283
2284 TokenStream::from(expanded)
2285 }
2286 Data::Union(_) => {
2287 panic!("`#[derive(ToPrompt)]` is not supported for unions");
2288 }
2289 }
2290}
2291
2292#[derive(Debug, Clone)]
2294struct TargetInfo {
2295 name: String,
2296 template: Option<String>,
2297 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
2298}
2299
/// Options controlling how one field is rendered for one target, gathered from the field's
/// `#[prompt_for(...)]` attributes.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// Leave the field out of the target's field listing.
    skip: bool,
    /// Label written in place of the field name in the `label: value` line.
    rename: Option<String>,
    /// Path of a function invoked as `f(&self.field)` to produce the value text.
    format_with: Option<String>,
    /// Contribute the field through its `to_prompt_parts()` rather than as a text line.
    image: bool,
    /// Set when these options came from an attribute naming a concrete target.
    include_only: bool,
}
2309
2310fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
2312 let mut configs = Vec::new();
2313
2314 for attr in attrs {
2315 if attr.path().is_ident("prompt_for")
2316 && let Ok(meta_list) = attr.meta.require_list()
2317 {
2318 if meta_list.tokens.to_string() == "skip" {
2320 let config = FieldTargetConfig {
2322 skip: true,
2323 ..Default::default()
2324 };
2325 configs.push(("*".to_string(), config));
2326 } else if let Ok(metas) =
2327 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2328 {
2329 let mut target_name = None;
2330 let mut config = FieldTargetConfig::default();
2331
2332 for meta in metas {
2333 match meta {
2334 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2335 if let syn::Expr::Lit(syn::ExprLit {
2336 lit: syn::Lit::Str(lit_str),
2337 ..
2338 }) = nv.value
2339 {
2340 target_name = Some(lit_str.value());
2341 }
2342 }
2343 Meta::Path(path) if path.is_ident("skip") => {
2344 config.skip = true;
2345 }
2346 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
2347 if let syn::Expr::Lit(syn::ExprLit {
2348 lit: syn::Lit::Str(lit_str),
2349 ..
2350 }) = nv.value
2351 {
2352 config.rename = Some(lit_str.value());
2353 }
2354 }
2355 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
2356 if let syn::Expr::Lit(syn::ExprLit {
2357 lit: syn::Lit::Str(lit_str),
2358 ..
2359 }) = nv.value
2360 {
2361 config.format_with = Some(lit_str.value());
2362 }
2363 }
2364 Meta::Path(path) if path.is_ident("image") => {
2365 config.image = true;
2366 }
2367 _ => {}
2368 }
2369 }
2370
2371 if let Some(name) = target_name {
2372 config.include_only = true;
2373 configs.push((name, config));
2374 }
2375 }
2376 }
2377 }
2378
2379 configs
2380}
2381
2382fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
2384 let mut targets = Vec::new();
2385
2386 for attr in attrs {
2387 if attr.path().is_ident("prompt_for")
2388 && let Ok(meta_list) = attr.meta.require_list()
2389 && let Ok(metas) =
2390 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2391 {
2392 let mut target_name = None;
2393 let mut template = None;
2394
2395 for meta in metas {
2396 match meta {
2397 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2398 if let syn::Expr::Lit(syn::ExprLit {
2399 lit: syn::Lit::Str(lit_str),
2400 ..
2401 }) = nv.value
2402 {
2403 target_name = Some(lit_str.value());
2404 }
2405 }
2406 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2407 if let syn::Expr::Lit(syn::ExprLit {
2408 lit: syn::Lit::Str(lit_str),
2409 ..
2410 }) = nv.value
2411 {
2412 template = Some(lit_str.value());
2413 }
2414 }
2415 _ => {}
2416 }
2417 }
2418
2419 if let Some(name) = target_name {
2420 targets.push(TargetInfo {
2421 name,
2422 template,
2423 field_configs: std::collections::HashMap::new(),
2424 });
2425 }
2426 }
2427 }
2428
2429 targets
2430}
2431
2432#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
2433pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
2434 let input = parse_macro_input!(input as DeriveInput);
2435
2436 let found_crate =
2437 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2438 let crate_path = match found_crate {
2439 FoundCrate::Itself => {
2440 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2442 quote!(::#ident)
2443 }
2444 FoundCrate::Name(name) => {
2445 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2446 quote!(::#ident)
2447 }
2448 };
2449
2450 let data_struct = match &input.data {
2452 Data::Struct(data) => data,
2453 _ => {
2454 return syn::Error::new(
2455 input.ident.span(),
2456 "`#[derive(ToPromptSet)]` is only supported for structs",
2457 )
2458 .to_compile_error()
2459 .into();
2460 }
2461 };
2462
2463 let fields = match &data_struct.fields {
2464 syn::Fields::Named(fields) => &fields.named,
2465 _ => {
2466 return syn::Error::new(
2467 input.ident.span(),
2468 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
2469 )
2470 .to_compile_error()
2471 .into();
2472 }
2473 };
2474
2475 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
2477
2478 for field in fields.iter() {
2480 let field_name = field.ident.as_ref().unwrap().to_string();
2481 let field_configs = parse_prompt_for_attrs(&field.attrs);
2482
2483 for (target_name, config) in field_configs {
2484 if target_name == "*" {
2485 for target in &mut targets {
2487 target
2488 .field_configs
2489 .entry(field_name.clone())
2490 .or_insert_with(FieldTargetConfig::default)
2491 .skip = config.skip;
2492 }
2493 } else {
2494 let target_exists = targets.iter().any(|t| t.name == target_name);
2496 if !target_exists {
2497 targets.push(TargetInfo {
2499 name: target_name.clone(),
2500 template: None,
2501 field_configs: std::collections::HashMap::new(),
2502 });
2503 }
2504
2505 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
2506
2507 target.field_configs.insert(field_name.clone(), config);
2508 }
2509 }
2510 }
2511
2512 let mut match_arms = Vec::new();
2514
2515 for target in &targets {
2516 let target_name = &target.name;
2517
2518 if let Some(template_str) = &target.template {
2519 let mut image_parts = Vec::new();
2521
2522 for field in fields.iter() {
2523 let field_name = field.ident.as_ref().unwrap();
2524 let field_name_str = field_name.to_string();
2525
2526 if let Some(config) = target.field_configs.get(&field_name_str)
2527 && config.image
2528 {
2529 image_parts.push(quote! {
2530 parts.extend(self.#field_name.to_prompt_parts());
2531 });
2532 }
2533 }
2534
2535 match_arms.push(quote! {
2536 #target_name => {
2537 let mut parts = Vec::new();
2538
2539 #(#image_parts)*
2540
2541 let text = #crate_path::prompt::render_prompt(#template_str, self)
2542 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
2543 target: #target_name.to_string(),
2544 source: e,
2545 })?;
2546
2547 if !text.is_empty() {
2548 parts.push(#crate_path::prompt::PromptPart::Text(text));
2549 }
2550
2551 Ok(parts)
2552 }
2553 });
2554 } else {
2555 let mut text_field_parts = Vec::new();
2557 let mut image_field_parts = Vec::new();
2558
2559 for field in fields.iter() {
2560 let field_name = field.ident.as_ref().unwrap();
2561 let field_name_str = field_name.to_string();
2562
2563 let config = target.field_configs.get(&field_name_str);
2565
2566 if let Some(cfg) = config
2568 && cfg.skip
2569 {
2570 continue;
2571 }
2572
2573 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
2577 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
2578 .iter()
2579 .any(|(name, _)| name != "*");
2580
2581 if has_any_target_specific_config && !is_explicitly_for_this_target {
2582 continue;
2583 }
2584
2585 if let Some(cfg) = config {
2586 if cfg.image {
2587 image_field_parts.push(quote! {
2588 parts.extend(self.#field_name.to_prompt_parts());
2589 });
2590 } else {
2591 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
2592
2593 let value_expr = if let Some(format_with) = &cfg.format_with {
2594 match syn::parse_str::<syn::Path>(format_with) {
2596 Ok(func_path) => quote! { #func_path(&self.#field_name) },
2597 Err(_) => {
2598 let error_msg = format!(
2600 "Invalid function path in format_with: '{}'",
2601 format_with
2602 );
2603 quote! {
2604 compile_error!(#error_msg);
2605 String::new()
2606 }
2607 }
2608 }
2609 } else {
2610 quote! { self.#field_name.to_prompt() }
2611 };
2612
2613 text_field_parts.push(quote! {
2614 text_parts.push(format!("{}: {}", #key, #value_expr));
2615 });
2616 }
2617 } else {
2618 text_field_parts.push(quote! {
2620 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
2621 });
2622 }
2623 }
2624
2625 match_arms.push(quote! {
2626 #target_name => {
2627 let mut parts = Vec::new();
2628
2629 #(#image_field_parts)*
2630
2631 let mut text_parts = Vec::new();
2632 #(#text_field_parts)*
2633
2634 if !text_parts.is_empty() {
2635 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2636 }
2637
2638 Ok(parts)
2639 }
2640 });
2641 }
2642 }
2643
2644 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
2646
2647 match_arms.push(quote! {
2649 _ => {
2650 let available = vec![#(#target_names.to_string()),*];
2651 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
2652 target: target.to_string(),
2653 available,
2654 })
2655 }
2656 });
2657
2658 let struct_name = &input.ident;
2659 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2660
2661 let expanded = quote! {
2662 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
2663 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
2664 match target {
2665 #(#match_arms)*
2666 }
2667 }
2668 }
2669 };
2670
2671 TokenStream::from(expanded)
2672}
2673
2674struct TypeList {
2676 types: Punctuated<syn::Type, Token![,]>,
2677}
2678
2679impl Parse for TypeList {
2680 fn parse(input: ParseStream) -> syn::Result<Self> {
2681 Ok(TypeList {
2682 types: Punctuated::parse_terminated(input)?,
2683 })
2684 }
2685}
2686
2687#[proc_macro]
2711pub fn examples_section(input: TokenStream) -> TokenStream {
2712 let input = parse_macro_input!(input as TypeList);
2713
2714 let found_crate =
2715 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2716 let _crate_path = match found_crate {
2717 FoundCrate::Itself => quote!(crate),
2718 FoundCrate::Name(name) => {
2719 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2720 quote!(::#ident)
2721 }
2722 };
2723
2724 let mut type_sections = Vec::new();
2726
2727 for ty in input.types.iter() {
2728 let type_name_str = quote!(#ty).to_string();
2730
2731 type_sections.push(quote! {
2733 {
2734 let type_name = #type_name_str;
2735 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
2736 format!("---\n#### `{}`\n{}", type_name, json_example)
2737 }
2738 });
2739 }
2740
2741 let expanded = quote! {
2743 {
2744 let mut sections = Vec::new();
2745 sections.push("---".to_string());
2746 sections.push("### Examples".to_string());
2747 sections.push("".to_string());
2748 sections.push("Here are examples of the data structures you should use.".to_string());
2749 sections.push("".to_string());
2750
2751 #(sections.push(#type_sections);)*
2752
2753 sections.push("---".to_string());
2754
2755 sections.join("\n")
2756 }
2757 };
2758
2759 TokenStream::from(expanded)
2760}
2761
2762fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
2764 for attr in attrs {
2765 if attr.path().is_ident("prompt_for")
2766 && let Ok(meta_list) = attr.meta.require_list()
2767 && let Ok(metas) =
2768 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2769 {
2770 let mut target_type = None;
2771 let mut template = None;
2772
2773 for meta in metas {
2774 match meta {
2775 Meta::NameValue(nv) if nv.path.is_ident("target") => {
2776 if let syn::Expr::Lit(syn::ExprLit {
2777 lit: syn::Lit::Str(lit_str),
2778 ..
2779 }) = nv.value
2780 {
2781 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
2783 }
2784 }
2785 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2786 if let syn::Expr::Lit(syn::ExprLit {
2787 lit: syn::Lit::Str(lit_str),
2788 ..
2789 }) = nv.value
2790 {
2791 template = Some(lit_str.value());
2792 }
2793 }
2794 _ => {}
2795 }
2796 }
2797
2798 if let (Some(target), Some(tmpl)) = (target_type, template) {
2799 return (target, tmpl);
2800 }
2801 }
2802 }
2803
2804 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
2805}
2806
2807#[proc_macro_attribute]
2841pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2842 let input = parse_macro_input!(item as DeriveInput);
2843
2844 let found_crate =
2845 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2846 let crate_path = match found_crate {
2847 FoundCrate::Itself => {
2848 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2850 quote!(::#ident)
2851 }
2852 FoundCrate::Name(name) => {
2853 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2854 quote!(::#ident)
2855 }
2856 };
2857
2858 let enum_data = match &input.data {
2860 Data::Enum(data) => data,
2861 _ => {
2862 return syn::Error::new(
2863 input.ident.span(),
2864 "`#[define_intent]` can only be applied to enums",
2865 )
2866 .to_compile_error()
2867 .into();
2868 }
2869 };
2870
2871 let mut prompt_template = None;
2873 let mut extractor_tag = None;
2874 let mut mode = None;
2875
2876 for attr in &input.attrs {
2877 if attr.path().is_ident("intent")
2878 && let Ok(metas) =
2879 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2880 {
2881 for meta in metas {
2882 match meta {
2883 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
2884 if let syn::Expr::Lit(syn::ExprLit {
2885 lit: syn::Lit::Str(lit_str),
2886 ..
2887 }) = nv.value
2888 {
2889 prompt_template = Some(lit_str.value());
2890 }
2891 }
2892 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
2893 if let syn::Expr::Lit(syn::ExprLit {
2894 lit: syn::Lit::Str(lit_str),
2895 ..
2896 }) = nv.value
2897 {
2898 extractor_tag = Some(lit_str.value());
2899 }
2900 }
2901 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
2902 if let syn::Expr::Lit(syn::ExprLit {
2903 lit: syn::Lit::Str(lit_str),
2904 ..
2905 }) = nv.value
2906 {
2907 mode = Some(lit_str.value());
2908 }
2909 }
2910 _ => {}
2911 }
2912 }
2913 }
2914 }
2915
2916 let mode = mode.unwrap_or_else(|| "single".to_string());
2918
2919 if mode != "single" && mode != "multi_tag" {
2921 return syn::Error::new(
2922 input.ident.span(),
2923 "`mode` must be either \"single\" or \"multi_tag\"",
2924 )
2925 .to_compile_error()
2926 .into();
2927 }
2928
2929 let prompt_template = match prompt_template {
2931 Some(p) => p,
2932 None => {
2933 return syn::Error::new(
2934 input.ident.span(),
2935 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2936 )
2937 .to_compile_error()
2938 .into();
2939 }
2940 };
2941
2942 if mode == "multi_tag" {
2944 let enum_name = &input.ident;
2945 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2946 return generate_multi_tag_output(
2947 &input,
2948 enum_name,
2949 enum_data,
2950 prompt_template,
2951 actions_doc,
2952 );
2953 }
2954
2955 let extractor_tag = match extractor_tag {
2957 Some(t) => t,
2958 None => {
2959 return syn::Error::new(
2960 input.ident.span(),
2961 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2962 )
2963 .to_compile_error()
2964 .into();
2965 }
2966 };
2967
2968 let enum_name = &input.ident;
2970 let enum_docs = extract_doc_comments(&input.attrs);
2971
2972 let mut intents_doc_lines = Vec::new();
2973
2974 if !enum_docs.is_empty() {
2976 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2977 } else {
2978 intents_doc_lines.push(format!("{}:", enum_name));
2979 }
2980 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2982
2983 for variant in &enum_data.variants {
2985 let variant_name = &variant.ident;
2986 let variant_docs = extract_doc_comments(&variant.attrs);
2987
2988 if !variant_docs.is_empty() {
2989 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2990 } else {
2991 intents_doc_lines.push(format!("- {}", variant_name));
2992 }
2993 }
2994
2995 let intents_doc_str = intents_doc_lines.join("\n");
2996
2997 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2999 let user_variables: Vec<String> = placeholders
3000 .iter()
3001 .filter_map(|(name, _)| {
3002 if name != "intents_doc" {
3003 Some(name.clone())
3004 } else {
3005 None
3006 }
3007 })
3008 .collect();
3009
3010 let enum_name_str = enum_name.to_string();
3012 let snake_case_name = to_snake_case(&enum_name_str);
3013 let function_name = syn::Ident::new(
3014 &format!("build_{}_prompt", snake_case_name),
3015 proc_macro2::Span::call_site(),
3016 );
3017
3018 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3020 .iter()
3021 .map(|var| {
3022 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3023 quote! { #ident: &str }
3024 })
3025 .collect();
3026
3027 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3029 .iter()
3030 .map(|var| {
3031 let var_str = var.clone();
3032 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3033 quote! {
3034 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
3035 }
3036 })
3037 .collect();
3038
3039 let converted_template = prompt_template.clone();
3041
3042 let extractor_name = syn::Ident::new(
3044 &format!("{}Extractor", enum_name),
3045 proc_macro2::Span::call_site(),
3046 );
3047
3048 let filtered_attrs: Vec<_> = input
3050 .attrs
3051 .iter()
3052 .filter(|attr| !attr.path().is_ident("intent"))
3053 .collect();
3054
3055 let vis = &input.vis;
3057 let generics = &input.generics;
3058 let variants = &enum_data.variants;
3059 let enum_output = quote! {
3060 #(#filtered_attrs)*
3061 #vis enum #enum_name #generics {
3062 #variants
3063 }
3064 };
3065
3066 let expanded = quote! {
3068 #enum_output
3070
3071 pub fn #function_name(#(#function_params),*) -> String {
3073 let mut env = minijinja::Environment::new();
3074 env.add_template("prompt", #converted_template)
3075 .expect("Failed to parse intent prompt template");
3076
3077 let tmpl = env.get_template("prompt").unwrap();
3078
3079 let mut __template_context = std::collections::HashMap::new();
3080
3081 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
3083
3084 #(#context_insertions)*
3086
3087 tmpl.render(&__template_context)
3088 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3089 }
3090
3091 pub struct #extractor_name;
3093
3094 impl #extractor_name {
3095 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
3096 }
3097
3098 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
3099 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
3100 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
3102 }
3103 }
3104 };
3105
3106 TokenStream::from(expanded)
3107}
3108
3109fn to_snake_case(s: &str) -> String {
3111 let mut result = String::new();
3112 let mut prev_upper = false;
3113
3114 for (i, ch) in s.chars().enumerate() {
3115 if ch.is_uppercase() {
3116 if i > 0 && !prev_upper {
3117 result.push('_');
3118 }
3119 result.push(ch.to_lowercase().next().unwrap());
3120 prev_upper = true;
3121 } else {
3122 result.push(ch);
3123 prev_upper = false;
3124 }
3125 }
3126
3127 result
3128}
3129
3130#[derive(Debug, Default)]
3132struct ActionAttrs {
3133 tag: Option<String>,
3134}
3135
3136fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
3137 let mut result = ActionAttrs::default();
3138
3139 for attr in attrs {
3140 if attr.path().is_ident("action")
3141 && let Ok(meta_list) = attr.meta.require_list()
3142 && let Ok(metas) =
3143 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3144 {
3145 for meta in metas {
3146 if let Meta::NameValue(nv) = meta
3147 && nv.path.is_ident("tag")
3148 && let syn::Expr::Lit(syn::ExprLit {
3149 lit: syn::Lit::Str(lit_str),
3150 ..
3151 }) = nv.value
3152 {
3153 result.tag = Some(lit_str.value());
3154 }
3155 }
3156 }
3157 }
3158
3159 result
3160}
3161
3162#[derive(Debug, Default)]
3164struct FieldActionAttrs {
3165 is_attribute: bool,
3166 is_inner_text: bool,
3167}
3168
3169fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
3170 let mut result = FieldActionAttrs::default();
3171
3172 for attr in attrs {
3173 if attr.path().is_ident("action")
3174 && let Ok(meta_list) = attr.meta.require_list()
3175 {
3176 let tokens_str = meta_list.tokens.to_string();
3177 if tokens_str == "attribute" {
3178 result.is_attribute = true;
3179 } else if tokens_str == "inner_text" {
3180 result.is_inner_text = true;
3181 }
3182 }
3183 }
3184
3185 result
3186}
3187
3188fn generate_multi_tag_actions_doc(
3190 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3191) -> String {
3192 let mut doc_lines = Vec::new();
3193
3194 for variant in variants {
3195 let action_attrs = parse_action_attrs(&variant.attrs);
3196
3197 if let Some(tag) = action_attrs.tag {
3198 let variant_docs = extract_doc_comments(&variant.attrs);
3199
3200 match &variant.fields {
3201 syn::Fields::Unit => {
3202 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
3204 }
3205 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
3206 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
3208 }
3209 syn::Fields::Named(fields) => {
3210 let mut attrs_str = Vec::new();
3212 let mut has_inner_text = false;
3213
3214 for field in &fields.named {
3215 let field_name = field.ident.as_ref().unwrap();
3216 let field_attrs = parse_field_action_attrs(&field.attrs);
3217
3218 if field_attrs.is_attribute {
3219 attrs_str.push(format!("{}=\"...\"", field_name));
3220 } else if field_attrs.is_inner_text {
3221 has_inner_text = true;
3222 }
3223 }
3224
3225 let attrs_part = if !attrs_str.is_empty() {
3226 format!(" {}", attrs_str.join(" "))
3227 } else {
3228 String::new()
3229 };
3230
3231 if has_inner_text {
3232 doc_lines.push(format!(
3233 "- `<{}{}>...</{}>`: {}",
3234 tag, attrs_part, tag, variant_docs
3235 ));
3236 } else if !attrs_str.is_empty() {
3237 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
3238 } else {
3239 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
3240 }
3241
3242 for field in &fields.named {
3244 let field_name = field.ident.as_ref().unwrap();
3245 let field_attrs = parse_field_action_attrs(&field.attrs);
3246 let field_docs = extract_doc_comments(&field.attrs);
3247
3248 if field_attrs.is_attribute {
3249 doc_lines
3250 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
3251 } else if field_attrs.is_inner_text {
3252 doc_lines
3253 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
3254 }
3255 }
3256 }
3257 _ => {
3258 }
3260 }
3261 }
3262 }
3263
3264 doc_lines.join("\n")
3265}
3266
3267fn generate_tags_regex(
3269 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3270) -> String {
3271 let mut tag_names = Vec::new();
3272
3273 for variant in variants {
3274 let action_attrs = parse_action_attrs(&variant.attrs);
3275 if let Some(tag) = action_attrs.tag {
3276 tag_names.push(tag);
3277 }
3278 }
3279
3280 if tag_names.is_empty() {
3281 return String::new();
3282 }
3283
3284 let tags_pattern = tag_names.join("|");
3285 format!(
3288 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
3289 tags_pattern, tags_pattern, tags_pattern
3290 )
3291}
3292
3293fn generate_multi_tag_output(
3295 input: &DeriveInput,
3296 enum_name: &syn::Ident,
3297 enum_data: &syn::DataEnum,
3298 prompt_template: String,
3299 actions_doc: String,
3300) -> TokenStream {
3301 let found_crate =
3302 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3303 let crate_path = match found_crate {
3304 FoundCrate::Itself => {
3305 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3307 quote!(::#ident)
3308 }
3309 FoundCrate::Name(name) => {
3310 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3311 quote!(::#ident)
3312 }
3313 };
3314
3315 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
3317 let user_variables: Vec<String> = placeholders
3318 .iter()
3319 .filter_map(|(name, _)| {
3320 if name != "actions_doc" {
3321 Some(name.clone())
3322 } else {
3323 None
3324 }
3325 })
3326 .collect();
3327
3328 let enum_name_str = enum_name.to_string();
3330 let snake_case_name = to_snake_case(&enum_name_str);
3331 let function_name = syn::Ident::new(
3332 &format!("build_{}_prompt", snake_case_name),
3333 proc_macro2::Span::call_site(),
3334 );
3335
3336 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3338 .iter()
3339 .map(|var| {
3340 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3341 quote! { #ident: &str }
3342 })
3343 .collect();
3344
3345 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3347 .iter()
3348 .map(|var| {
3349 let var_str = var.clone();
3350 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3351 quote! {
3352 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
3353 }
3354 })
3355 .collect();
3356
3357 let extractor_name = syn::Ident::new(
3359 &format!("{}Extractor", enum_name),
3360 proc_macro2::Span::call_site(),
3361 );
3362
3363 let filtered_attrs: Vec<_> = input
3365 .attrs
3366 .iter()
3367 .filter(|attr| !attr.path().is_ident("intent"))
3368 .collect();
3369
3370 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
3372 .variants
3373 .iter()
3374 .map(|variant| {
3375 let variant_name = &variant.ident;
3376 let variant_attrs: Vec<_> = variant
3377 .attrs
3378 .iter()
3379 .filter(|attr| !attr.path().is_ident("action"))
3380 .collect();
3381 let fields = &variant.fields;
3382
3383 let filtered_fields = match fields {
3385 syn::Fields::Named(named_fields) => {
3386 let filtered: Vec<_> = named_fields
3387 .named
3388 .iter()
3389 .map(|field| {
3390 let field_name = &field.ident;
3391 let field_type = &field.ty;
3392 let field_vis = &field.vis;
3393 let filtered_attrs: Vec<_> = field
3394 .attrs
3395 .iter()
3396 .filter(|attr| !attr.path().is_ident("action"))
3397 .collect();
3398 quote! {
3399 #(#filtered_attrs)*
3400 #field_vis #field_name: #field_type
3401 }
3402 })
3403 .collect();
3404 quote! { { #(#filtered,)* } }
3405 }
3406 syn::Fields::Unnamed(unnamed_fields) => {
3407 let types: Vec<_> = unnamed_fields
3408 .unnamed
3409 .iter()
3410 .map(|field| {
3411 let field_type = &field.ty;
3412 quote! { #field_type }
3413 })
3414 .collect();
3415 quote! { (#(#types),*) }
3416 }
3417 syn::Fields::Unit => quote! {},
3418 };
3419
3420 quote! {
3421 #(#variant_attrs)*
3422 #variant_name #filtered_fields
3423 }
3424 })
3425 .collect();
3426
3427 let vis = &input.vis;
3428 let generics = &input.generics;
3429
3430 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
3432
3433 let tags_regex = generate_tags_regex(&enum_data.variants);
3435
3436 let expanded = quote! {
3437 #(#filtered_attrs)*
3439 #vis enum #enum_name #generics {
3440 #(#filtered_variants),*
3441 }
3442
3443 pub fn #function_name(#(#function_params),*) -> String {
3445 let mut env = minijinja::Environment::new();
3446 env.add_template("prompt", #prompt_template)
3447 .expect("Failed to parse intent prompt template");
3448
3449 let tmpl = env.get_template("prompt").unwrap();
3450
3451 let mut __template_context = std::collections::HashMap::new();
3452
3453 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
3455
3456 #(#context_insertions)*
3458
3459 tmpl.render(&__template_context)
3460 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3461 }
3462
3463 pub struct #extractor_name;
3465
3466 impl #extractor_name {
3467 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
3468 use ::quick_xml::events::Event;
3469 use ::quick_xml::Reader;
3470
3471 let mut actions = Vec::new();
3472 let mut reader = Reader::from_str(text);
3473 reader.config_mut().trim_text(true);
3474
3475 let mut buf = Vec::new();
3476
3477 loop {
3478 match reader.read_event_into(&mut buf) {
3479 Ok(Event::Start(e)) => {
3480 let owned_e = e.into_owned();
3481 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3482 let is_empty = false;
3483
3484 #parsing_arms
3485 }
3486 Ok(Event::Empty(e)) => {
3487 let owned_e = e.into_owned();
3488 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3489 let is_empty = true;
3490
3491 #parsing_arms
3492 }
3493 Ok(Event::Eof) => break,
3494 Err(_) => {
3495 break;
3497 }
3498 _ => {}
3499 }
3500 buf.clear();
3501 }
3502
3503 actions.into_iter().next()
3504 }
3505
3506 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
3507 use ::quick_xml::events::Event;
3508 use ::quick_xml::Reader;
3509
3510 let mut actions = Vec::new();
3511 let mut reader = Reader::from_str(text);
3512 reader.config_mut().trim_text(true);
3513
3514 let mut buf = Vec::new();
3515
3516 loop {
3517 match reader.read_event_into(&mut buf) {
3518 Ok(Event::Start(e)) => {
3519 let owned_e = e.into_owned();
3520 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3521 let is_empty = false;
3522
3523 #parsing_arms
3524 }
3525 Ok(Event::Empty(e)) => {
3526 let owned_e = e.into_owned();
3527 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3528 let is_empty = true;
3529
3530 #parsing_arms
3531 }
3532 Ok(Event::Eof) => break,
3533 Err(_) => {
3534 break;
3536 }
3537 _ => {}
3538 }
3539 buf.clear();
3540 }
3541
3542 Ok(actions)
3543 }
3544
3545 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
3546 where
3547 F: FnMut(#enum_name) -> String,
3548 {
3549 use ::regex::Regex;
3550
3551 let regex_pattern = #tags_regex;
3552 if regex_pattern.is_empty() {
3553 return text.to_string();
3554 }
3555
3556 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
3557 panic!("Failed to compile regex for action tags: {}", e);
3558 });
3559
3560 re.replace_all(text, |caps: &::regex::Captures| {
3561 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
3562
3563 if let Some(action) = self.parse_single_action(matched) {
3565 transformer(action)
3566 } else {
3567 matched.to_string()
3569 }
3570 }).to_string()
3571 }
3572
3573 pub fn strip_actions(&self, text: &str) -> String {
3574 self.transform_actions(text, |_| String::new())
3575 }
3576 }
3577 };
3578
3579 TokenStream::from(expanded)
3580}
3581
3582fn generate_parsing_arms(
3584 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3585 enum_name: &syn::Ident,
3586) -> proc_macro2::TokenStream {
3587 let mut arms = Vec::new();
3588
3589 for variant in variants {
3590 let variant_name = &variant.ident;
3591 let action_attrs = parse_action_attrs(&variant.attrs);
3592
3593 if let Some(tag) = action_attrs.tag {
3594 match &variant.fields {
3595 syn::Fields::Unit => {
3596 arms.push(quote! {
3598 if &tag_name == #tag {
3599 actions.push(#enum_name::#variant_name);
3600 }
3601 });
3602 }
3603 syn::Fields::Unnamed(_fields) => {
3604 arms.push(quote! {
3606 if &tag_name == #tag && !is_empty {
3607 match reader.read_text(owned_e.name()) {
3609 Ok(text) => {
3610 actions.push(#enum_name::#variant_name(text.to_string()));
3611 }
3612 Err(_) => {
3613 actions.push(#enum_name::#variant_name(String::new()));
3615 }
3616 }
3617 }
3618 });
3619 }
3620 syn::Fields::Named(fields) => {
3621 let mut field_names = Vec::new();
3623 let mut has_inner_text_field = None;
3624
3625 for field in &fields.named {
3626 let field_name = field.ident.as_ref().unwrap();
3627 let field_attrs = parse_field_action_attrs(&field.attrs);
3628
3629 if field_attrs.is_attribute {
3630 field_names.push(field_name.clone());
3631 } else if field_attrs.is_inner_text {
3632 has_inner_text_field = Some(field_name.clone());
3633 }
3634 }
3635
3636 if let Some(inner_text_field) = has_inner_text_field {
3637 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3640 quote! {
3641 let mut #field_name = String::new();
3642 for attr in owned_e.attributes() {
3643 if let Ok(attr) = attr {
3644 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3645 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3646 break;
3647 }
3648 }
3649 }
3650 }
3651 }).collect();
3652
3653 arms.push(quote! {
3654 if &tag_name == #tag {
3655 #(#attr_extractions)*
3656
3657 if is_empty {
3659 let #inner_text_field = String::new();
3660 actions.push(#enum_name::#variant_name {
3661 #(#field_names,)*
3662 #inner_text_field,
3663 });
3664 } else {
3665 match reader.read_text(owned_e.name()) {
3667 Ok(text) => {
3668 let #inner_text_field = text.to_string();
3669 actions.push(#enum_name::#variant_name {
3670 #(#field_names,)*
3671 #inner_text_field,
3672 });
3673 }
3674 Err(_) => {
3675 let #inner_text_field = String::new();
3677 actions.push(#enum_name::#variant_name {
3678 #(#field_names,)*
3679 #inner_text_field,
3680 });
3681 }
3682 }
3683 }
3684 }
3685 });
3686 } else {
3687 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3689 quote! {
3690 let mut #field_name = String::new();
3691 for attr in owned_e.attributes() {
3692 if let Ok(attr) = attr {
3693 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3694 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3695 break;
3696 }
3697 }
3698 }
3699 }
3700 }).collect();
3701
3702 arms.push(quote! {
3703 if &tag_name == #tag {
3704 #(#attr_extractions)*
3705 actions.push(#enum_name::#variant_name {
3706 #(#field_names),*
3707 });
3708 }
3709 });
3710 }
3711 }
3712 }
3713 }
3714 }
3715
3716 quote! {
3717 #(#arms)*
3718 }
3719}
3720
3721#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3723pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3724 let input = parse_macro_input!(input as DeriveInput);
3725
3726 let found_crate =
3727 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3728 let crate_path = match found_crate {
3729 FoundCrate::Itself => {
3730 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3732 quote!(::#ident)
3733 }
3734 FoundCrate::Name(name) => {
3735 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3736 quote!(::#ident)
3737 }
3738 };
3739
3740 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
3742
3743 let struct_name = &input.ident;
3744 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3745
3746 let placeholders = parse_template_placeholders_with_mode(&template);
3748
3749 let mut converted_template = template.clone();
3751 let mut context_fields = Vec::new();
3752
3753 let fields = match &input.data {
3755 Data::Struct(data_struct) => match &data_struct.fields {
3756 syn::Fields::Named(fields) => &fields.named,
3757 _ => panic!("ToPromptFor is only supported for structs with named fields"),
3758 },
3759 _ => panic!("ToPromptFor is only supported for structs"),
3760 };
3761
3762 let has_mode_support = input.attrs.iter().any(|attr| {
3764 if attr.path().is_ident("prompt")
3765 && let Ok(metas) =
3766 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3767 {
3768 for meta in metas {
3769 if let Meta::NameValue(nv) = meta
3770 && nv.path.is_ident("mode")
3771 {
3772 return true;
3773 }
3774 }
3775 }
3776 false
3777 });
3778
3779 for (placeholder_name, mode_opt) in &placeholders {
3781 if placeholder_name == "self" {
3782 if let Some(specific_mode) = mode_opt {
3783 let unique_key = format!("self__{}", specific_mode);
3785
3786 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
3788 let replacement = format!("{{{{ {} }}}}", unique_key);
3789 converted_template = converted_template.replace(&pattern, &replacement);
3790
3791 context_fields.push(quote! {
3793 context.insert(
3794 #unique_key.to_string(),
3795 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
3796 );
3797 });
3798 } else {
3799 if has_mode_support {
3802 context_fields.push(quote! {
3804 context.insert(
3805 "self".to_string(),
3806 minijinja::Value::from(self.to_prompt_with_mode(mode))
3807 );
3808 });
3809 } else {
3810 context_fields.push(quote! {
3812 context.insert(
3813 "self".to_string(),
3814 minijinja::Value::from(self.to_prompt())
3815 );
3816 });
3817 }
3818 }
3819 } else {
3820 let field_exists = fields.iter().any(|f| {
3823 f.ident
3824 .as_ref()
3825 .is_some_and(|ident| ident == placeholder_name)
3826 });
3827
3828 if field_exists {
3829 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
3830
3831 context_fields.push(quote! {
3835 context.insert(
3836 #placeholder_name.to_string(),
3837 minijinja::Value::from_serialize(&self.#field_ident)
3838 );
3839 });
3840 }
3841 }
3843 }
3844
3845 let expanded = quote! {
3846 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
3847 where
3848 #target_type: serde::Serialize,
3849 {
3850 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
3851 let mut env = minijinja::Environment::new();
3853 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
3854 panic!("Failed to parse template: {}", e)
3855 });
3856
3857 let tmpl = env.get_template("prompt").unwrap();
3858
3859 let mut context = std::collections::HashMap::new();
3861 context.insert(
3863 "self".to_string(),
3864 minijinja::Value::from_serialize(self)
3865 );
3866 context.insert(
3868 "target".to_string(),
3869 minijinja::Value::from_serialize(target)
3870 );
3871 #(#context_fields)*
3872
3873 tmpl.render(context).unwrap_or_else(|e| {
3875 format!("Failed to render prompt: {}", e)
3876 })
3877 }
3878 }
3879 };
3880
3881 TokenStream::from(expanded)
3882}
3883
3884struct AgentAttrs {
3890 expertise: Option<String>,
3891 output: Option<syn::Type>,
3892 backend: Option<String>,
3893 model: Option<String>,
3894 inner: Option<String>,
3895 default_inner: Option<String>,
3896 max_retries: Option<u32>,
3897 profile: Option<String>,
3898}
3899
3900impl Parse for AgentAttrs {
3901 fn parse(input: ParseStream) -> syn::Result<Self> {
3902 let mut expertise = None;
3903 let mut output = None;
3904 let mut backend = None;
3905 let mut model = None;
3906 let mut inner = None;
3907 let mut default_inner = None;
3908 let mut max_retries = None;
3909 let mut profile = None;
3910
3911 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3912
3913 for meta in pairs {
3914 match meta {
3915 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3916 if let syn::Expr::Lit(syn::ExprLit {
3917 lit: syn::Lit::Str(lit_str),
3918 ..
3919 }) = &nv.value
3920 {
3921 expertise = Some(lit_str.value());
3922 }
3923 }
3924 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3925 if let syn::Expr::Lit(syn::ExprLit {
3926 lit: syn::Lit::Str(lit_str),
3927 ..
3928 }) = &nv.value
3929 {
3930 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3931 output = Some(ty);
3932 }
3933 }
3934 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3935 if let syn::Expr::Lit(syn::ExprLit {
3936 lit: syn::Lit::Str(lit_str),
3937 ..
3938 }) = &nv.value
3939 {
3940 backend = Some(lit_str.value());
3941 }
3942 }
3943 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3944 if let syn::Expr::Lit(syn::ExprLit {
3945 lit: syn::Lit::Str(lit_str),
3946 ..
3947 }) = &nv.value
3948 {
3949 model = Some(lit_str.value());
3950 }
3951 }
3952 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3953 if let syn::Expr::Lit(syn::ExprLit {
3954 lit: syn::Lit::Str(lit_str),
3955 ..
3956 }) = &nv.value
3957 {
3958 inner = Some(lit_str.value());
3959 }
3960 }
3961 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3962 if let syn::Expr::Lit(syn::ExprLit {
3963 lit: syn::Lit::Str(lit_str),
3964 ..
3965 }) = &nv.value
3966 {
3967 default_inner = Some(lit_str.value());
3968 }
3969 }
3970 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3971 if let syn::Expr::Lit(syn::ExprLit {
3972 lit: syn::Lit::Int(lit_int),
3973 ..
3974 }) = &nv.value
3975 {
3976 max_retries = Some(lit_int.base10_parse()?);
3977 }
3978 }
3979 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3980 if let syn::Expr::Lit(syn::ExprLit {
3981 lit: syn::Lit::Str(lit_str),
3982 ..
3983 }) = &nv.value
3984 {
3985 profile = Some(lit_str.value());
3986 }
3987 }
3988 _ => {}
3989 }
3990 }
3991
3992 Ok(AgentAttrs {
3993 expertise,
3994 output,
3995 backend,
3996 model,
3997 inner,
3998 default_inner,
3999 max_retries,
4000 profile,
4001 })
4002 }
4003}
4004
4005fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
4007 for attr in attrs {
4008 if attr.path().is_ident("agent") {
4009 return attr.parse_args::<AgentAttrs>();
4010 }
4011 }
4012
4013 Ok(AgentAttrs {
4014 expertise: None,
4015 output: None,
4016 backend: None,
4017 model: None,
4018 inner: None,
4019 default_inner: None,
4020 max_retries: None,
4021 profile: None,
4022 })
4023}
4024
4025fn generate_backend_constructors(
4027 struct_name: &syn::Ident,
4028 backend: &str,
4029 _model: Option<&str>,
4030 _profile: Option<&str>,
4031 crate_path: &proc_macro2::TokenStream,
4032) -> proc_macro2::TokenStream {
4033 match backend {
4034 "claude" => {
4035 quote! {
4036 impl #struct_name {
4037 pub fn with_claude() -> Self {
4039 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
4040 }
4041
4042 pub fn with_claude_model(model: &str) -> Self {
4044 Self::new(
4045 #crate_path::agent::impls::ClaudeCodeAgent::new()
4046 .with_model_str(model)
4047 )
4048 }
4049 }
4050 }
4051 }
4052 "gemini" => {
4053 quote! {
4054 impl #struct_name {
4055 pub fn with_gemini() -> Self {
4057 Self::new(#crate_path::agent::impls::GeminiAgent::new())
4058 }
4059
4060 pub fn with_gemini_model(model: &str) -> Self {
4062 Self::new(
4063 #crate_path::agent::impls::GeminiAgent::new()
4064 .with_model_str(model)
4065 )
4066 }
4067 }
4068 }
4069 }
4070 _ => quote! {},
4071 }
4072}
4073
4074fn generate_default_impl(
4076 struct_name: &syn::Ident,
4077 backend: &str,
4078 model: Option<&str>,
4079 profile: Option<&str>,
4080 crate_path: &proc_macro2::TokenStream,
4081) -> proc_macro2::TokenStream {
4082 let profile_expr = if let Some(profile_str) = profile {
4084 match profile_str.to_lowercase().as_str() {
4085 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4086 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4087 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
4088 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
4090 } else {
4091 quote! { #crate_path::agent::ExecutionProfile::default() }
4092 };
4093
4094 let agent_init = match backend {
4095 "gemini" => {
4096 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
4097
4098 if let Some(model_str) = model {
4099 builder = quote! { #builder.with_model_str(#model_str) };
4100 }
4101
4102 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4103 builder
4104 }
4105 _ => {
4106 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
4108
4109 if let Some(model_str) = model {
4110 builder = quote! { #builder.with_model_str(#model_str) };
4111 }
4112
4113 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4114 builder
4115 }
4116 };
4117
4118 quote! {
4119 impl Default for #struct_name {
4120 fn default() -> Self {
4121 Self::new(#agent_init)
4122 }
4123 }
4124 }
4125}
4126
4127#[proc_macro_derive(Agent, attributes(agent))]
4136pub fn derive_agent(input: TokenStream) -> TokenStream {
4137 let input = parse_macro_input!(input as DeriveInput);
4138 let struct_name = &input.ident;
4139
4140 let agent_attrs = match parse_agent_attrs(&input.attrs) {
4142 Ok(attrs) => attrs,
4143 Err(e) => return e.to_compile_error().into(),
4144 };
4145
4146 let expertise = agent_attrs
4147 .expertise
4148 .unwrap_or_else(|| String::from("general AI assistant"));
4149 let output_type = agent_attrs
4150 .output
4151 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4152 let backend = agent_attrs
4153 .backend
4154 .unwrap_or_else(|| String::from("claude"));
4155 let model = agent_attrs.model;
4156 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
4161 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4162 let crate_path = match found_crate {
4163 FoundCrate::Itself => {
4164 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4166 quote!(::#ident)
4167 }
4168 FoundCrate::Name(name) => {
4169 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4170 quote!(::#ident)
4171 }
4172 };
4173
4174 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
4175
4176 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4178 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4179
4180 let enhanced_expertise = if is_string_output {
4182 quote! { #expertise }
4184 } else {
4185 let type_name = quote!(#output_type).to_string();
4187 quote! {
4188 {
4189 use std::sync::OnceLock;
4190 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4191
4192 EXPERTISE_CACHE.get_or_init(|| {
4193 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4195
4196 if schema.is_empty() {
4197 format!(
4199 concat!(
4200 #expertise,
4201 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4202 "Do not include any text outside the JSON object."
4203 ),
4204 #type_name
4205 )
4206 } else {
4207 format!(
4209 concat!(
4210 #expertise,
4211 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4212 ),
4213 schema
4214 )
4215 }
4216 }).as_str()
4217 }
4218 }
4219 };
4220
4221 let agent_init = match backend.as_str() {
4223 "gemini" => {
4224 if let Some(model_str) = model {
4225 quote! {
4226 use #crate_path::agent::impls::GeminiAgent;
4227 let agent = GeminiAgent::new().with_model_str(#model_str);
4228 }
4229 } else {
4230 quote! {
4231 use #crate_path::agent::impls::GeminiAgent;
4232 let agent = GeminiAgent::new();
4233 }
4234 }
4235 }
4236 "claude" => {
4237 if let Some(model_str) = model {
4238 quote! {
4239 use #crate_path::agent::impls::ClaudeCodeAgent;
4240 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
4241 }
4242 } else {
4243 quote! {
4244 use #crate_path::agent::impls::ClaudeCodeAgent;
4245 let agent = ClaudeCodeAgent::new();
4246 }
4247 }
4248 }
4249 _ => {
4250 if let Some(model_str) = model {
4252 quote! {
4253 use #crate_path::agent::impls::ClaudeCodeAgent;
4254 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
4255 }
4256 } else {
4257 quote! {
4258 use #crate_path::agent::impls::ClaudeCodeAgent;
4259 let agent = ClaudeCodeAgent::new();
4260 }
4261 }
4262 }
4263 };
4264
4265 let expanded = quote! {
4266 #[async_trait::async_trait]
4267 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
4268 type Output = #output_type;
4269
4270 fn expertise(&self) -> &str {
4271 #enhanced_expertise
4272 }
4273
4274 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4275 #agent_init
4277
4278 let agent_ref = &agent;
4280 #crate_path::agent::retry::retry_execution(
4281 #max_retries,
4282 &intent,
4283 move |payload| {
4284 let payload = payload.clone();
4285 async move {
4286 let response = agent_ref.execute(payload).await?;
4288
4289 let json_str = #crate_path::extract_json(&response)
4291 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4292 message: format!("Failed to extract JSON: {}", e),
4293 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4294 })?;
4295
4296 serde_json::from_str::<Self::Output>(&json_str)
4298 .map_err(|e| {
4299 let reason = if e.is_eof() {
4301 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4302 } else if e.is_syntax() {
4303 #crate_path::agent::error::ParseErrorReason::InvalidJson
4304 } else {
4305 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4306 };
4307
4308 #crate_path::agent::AgentError::ParseError {
4309 message: format!("Failed to parse JSON: {}", e),
4310 reason,
4311 }
4312 })
4313 }
4314 }
4315 ).await
4316 }
4317
4318 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4319 #agent_init
4321 agent.is_available().await
4322 }
4323 }
4324 };
4325
4326 TokenStream::from(expanded)
4327}
4328
4329#[proc_macro_attribute]
4344pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
4345 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
4347 Ok(attrs) => attrs,
4348 Err(e) => return e.to_compile_error().into(),
4349 };
4350
4351 let input = parse_macro_input!(item as DeriveInput);
4353 let struct_name = &input.ident;
4354 let vis = &input.vis;
4355
4356 let expertise = agent_attrs
4357 .expertise
4358 .unwrap_or_else(|| String::from("general AI assistant"));
4359 let output_type = agent_attrs
4360 .output
4361 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4362 let backend = agent_attrs
4363 .backend
4364 .unwrap_or_else(|| String::from("claude"));
4365 let model = agent_attrs.model;
4366 let profile = agent_attrs.profile;
4367
4368 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4370 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4371
4372 let found_crate =
4374 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4375 let crate_path = match found_crate {
4376 FoundCrate::Itself => {
4377 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4378 quote!(::#ident)
4379 }
4380 FoundCrate::Name(name) => {
4381 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4382 quote!(::#ident)
4383 }
4384 };
4385
4386 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
4388 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
4389
4390 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
4392 let type_path: syn::Type =
4394 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
4395 quote! { #type_path }
4396 } else {
4397 match backend.as_str() {
4399 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
4400 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
4401 }
4402 };
4403
4404 let struct_def = quote! {
4406 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
4407 inner: #inner_generic_ident,
4408 }
4409 };
4410
4411 let constructors = quote! {
4413 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
4414 pub fn new(inner: #inner_generic_ident) -> Self {
4416 Self { inner }
4417 }
4418 }
4419 };
4420
4421 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
4423 let default_impl = quote! {
4425 impl Default for #struct_name {
4426 fn default() -> Self {
4427 Self {
4428 inner: <#default_agent_type as Default>::default(),
4429 }
4430 }
4431 }
4432 };
4433 (quote! {}, default_impl)
4434 } else {
4435 let backend_constructors = generate_backend_constructors(
4437 struct_name,
4438 &backend,
4439 model.as_deref(),
4440 profile.as_deref(),
4441 &crate_path,
4442 );
4443 let default_impl = generate_default_impl(
4444 struct_name,
4445 &backend,
4446 model.as_deref(),
4447 profile.as_deref(),
4448 &crate_path,
4449 );
4450 (backend_constructors, default_impl)
4451 };
4452
4453 let enhanced_expertise = if is_string_output {
4455 quote! { #expertise }
4457 } else {
4458 let type_name = quote!(#output_type).to_string();
4460 quote! {
4461 {
4462 use std::sync::OnceLock;
4463 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4464
4465 EXPERTISE_CACHE.get_or_init(|| {
4466 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4468
4469 if schema.is_empty() {
4470 format!(
4472 concat!(
4473 #expertise,
4474 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4475 "Do not include any text outside the JSON object."
4476 ),
4477 #type_name
4478 )
4479 } else {
4480 format!(
4482 concat!(
4483 #expertise,
4484 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4485 ),
4486 schema
4487 )
4488 }
4489 }).as_str()
4490 }
4491 }
4492 };
4493
4494 let agent_impl = quote! {
4496 #[async_trait::async_trait]
4497 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4498 where
4499 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
4500 {
4501 type Output = #output_type;
4502
4503 fn expertise(&self) -> &str {
4504 #enhanced_expertise
4505 }
4506
4507 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4508 let enhanced_payload = intent.prepend_text(self.expertise());
4510
4511 let response = self.inner.execute(enhanced_payload).await?;
4513
4514 let json_str = #crate_path::extract_json(&response)
4516 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4517 message: e.to_string(),
4518 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4519 })?;
4520
4521 serde_json::from_str(&json_str).map_err(|e| {
4523 let reason = if e.is_eof() {
4524 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4525 } else if e.is_syntax() {
4526 #crate_path::agent::error::ParseErrorReason::InvalidJson
4527 } else {
4528 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4529 };
4530 #crate_path::agent::AgentError::ParseError {
4531 message: e.to_string(),
4532 reason,
4533 }
4534 })
4535 }
4536
4537 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4538 self.inner.is_available().await
4539 }
4540 }
4541 };
4542
4543 let expanded = quote! {
4544 #struct_def
4545 #constructors
4546 #backend_constructors
4547 #default_impl
4548 #agent_impl
4549 };
4550
4551 TokenStream::from(expanded)
4552}
4553
4554#[proc_macro_derive(TypeMarker)]
4576pub fn derive_type_marker(input: TokenStream) -> TokenStream {
4577 let input = parse_macro_input!(input as DeriveInput);
4578 let struct_name = &input.ident;
4579 let type_name_str = struct_name.to_string();
4580
4581 let found_crate =
4583 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4584 let crate_path = match found_crate {
4585 FoundCrate::Itself => {
4586 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4587 quote!(::#ident)
4588 }
4589 FoundCrate::Name(name) => {
4590 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4591 quote!(::#ident)
4592 }
4593 };
4594
4595 let expanded = quote! {
4596 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4597 const TYPE_NAME: &'static str = #type_name_str;
4598 }
4599 };
4600
4601 TokenStream::from(expanded)
4602}
4603
4604#[proc_macro_attribute]
4640pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
4641 let input = parse_macro_input!(item as syn::DeriveInput);
4642 let struct_name = &input.ident;
4643 let vis = &input.vis;
4644 let type_name_str = struct_name.to_string();
4645
4646 let default_fn_name = syn::Ident::new(
4648 &format!("default_{}_type", to_snake_case(&type_name_str)),
4649 struct_name.span(),
4650 );
4651
4652 let found_crate =
4654 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4655 let crate_path = match found_crate {
4656 FoundCrate::Itself => {
4657 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4658 quote!(::#ident)
4659 }
4660 FoundCrate::Name(name) => {
4661 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4662 quote!(::#ident)
4663 }
4664 };
4665
4666 let fields = match &input.data {
4668 syn::Data::Struct(data_struct) => match &data_struct.fields {
4669 syn::Fields::Named(fields) => &fields.named,
4670 _ => {
4671 return syn::Error::new_spanned(
4672 struct_name,
4673 "type_marker only works with structs with named fields",
4674 )
4675 .to_compile_error()
4676 .into();
4677 }
4678 },
4679 _ => {
4680 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
4681 .to_compile_error()
4682 .into();
4683 }
4684 };
4685
4686 let mut new_fields = vec![];
4688
4689 let default_fn_name_str = default_fn_name.to_string();
4691 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
4692
4693 new_fields.push(quote! {
4698 #[serde(default = #default_fn_name_lit)]
4699 __type: String
4700 });
4701
4702 for field in fields {
4704 new_fields.push(quote! { #field });
4705 }
4706
4707 let attrs = &input.attrs;
4709 let generics = &input.generics;
4710
4711 let expanded = quote! {
4712 fn #default_fn_name() -> String {
4714 #type_name_str.to_string()
4715 }
4716
4717 #(#attrs)*
4719 #vis struct #struct_name #generics {
4720 #(#new_fields),*
4721 }
4722
4723 impl #crate_path::orchestrator::TypeMarker for #struct_name {
4725 const TYPE_NAME: &'static str = #type_name_str;
4726 }
4727 };
4728
4729 TokenStream::from(expanded)
4730}