1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, vec_inner_type) = extract_vec_inner_type(&field.ty);
169 let (is_option, option_inner_type) = extract_option_inner_type(&field.ty);
170 let (is_map, map_value_type) = extract_map_value_type(&field.ty);
171 let (is_set, set_element_type) = extract_set_element_type(&field.ty);
172
173 if is_vec {
174 let comment = if !field_docs.is_empty() {
177 format!(" // {}", field_docs)
178 } else {
179 String::new()
180 };
181
182 field_schema_parts.push(quote! {
183 {
184 let type_name = stringify!(#vec_inner_type);
185 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
186 }
187 });
188
189 if let Some(inner) = vec_inner_type
191 && !is_primitive_type(inner)
192 {
193 nested_type_collectors.push(quote! {
194 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
195 });
196 }
197 } else if is_option {
198 let comment = if !field_docs.is_empty() {
201 format!(" // {}", field_docs)
202 } else {
203 String::new()
204 };
205
206 if let Some(inner) = option_inner_type {
207 let (inner_is_map, inner_map_value) = extract_map_value_type(inner);
209 let (inner_is_set, inner_set_element) = extract_set_element_type(inner);
210 let (inner_is_vec, inner_vec_element) = extract_vec_inner_type(inner);
211
212 if inner_is_map {
213 if let Some(value_type) = inner_map_value {
215 let is_value_primitive = is_primitive_type(value_type);
216 if !is_value_primitive {
217 field_schema_parts.push(quote! {
218 {
219 let type_name = stringify!(#value_type);
220 format!(" {}: Record<string, {}> | null;{}", #field_name_str, type_name, #comment)
221 }
222 });
223 nested_type_collectors.push(quote! {
224 <#value_type as #crate_path::prompt::ToPrompt>::prompt_schema()
225 });
226 } else {
227 let type_str = format_type_for_schema(value_type);
228 field_schema_parts.push(quote! {
229 format!(" {}: Record<string, {}> | null;{}", #field_name_str, #type_str, #comment)
230 });
231 }
232 }
233 } else if inner_is_set {
234 if let Some(element_type) = inner_set_element {
236 let is_element_primitive = is_primitive_type(element_type);
237 if !is_element_primitive {
238 field_schema_parts.push(quote! {
239 {
240 let type_name = stringify!(#element_type);
241 format!(" {}: {}[] | null;{}", #field_name_str, type_name, #comment)
242 }
243 });
244 nested_type_collectors.push(quote! {
245 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
246 });
247 } else {
248 let type_str = format_type_for_schema(element_type);
249 field_schema_parts.push(quote! {
250 format!(" {}: {}[] | null;{}", #field_name_str, #type_str, #comment)
251 });
252 }
253 }
254 } else if inner_is_vec {
255 if let Some(element_type) = inner_vec_element {
257 let is_element_primitive = is_primitive_type(element_type);
258 if !is_element_primitive {
259 field_schema_parts.push(quote! {
260 {
261 let type_name = stringify!(#element_type);
262 format!(" {}: {}[] | null;{}", #field_name_str, type_name, #comment)
263 }
264 });
265 nested_type_collectors.push(quote! {
266 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
267 });
268 } else {
269 let type_str = format_type_for_schema(element_type);
270 field_schema_parts.push(quote! {
271 format!(" {}: {}[] | null;{}", #field_name_str, #type_str, #comment)
272 });
273 }
274 }
275 } else {
276 let is_inner_primitive = is_primitive_type(inner);
278
279 if !is_inner_primitive {
280 field_schema_parts.push(quote! {
282 {
283 let type_name = stringify!(#inner);
284 format!(" {}: {} | null;{}", #field_name_str, type_name, #comment)
285 }
286 });
287
288 nested_type_collectors.push(quote! {
290 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
291 });
292 } else {
293 let type_str = format_type_for_schema(inner);
295 field_schema_parts.push(quote! {
296 format!(" {}: {} | null;{}", #field_name_str, #type_str, #comment)
297 });
298 }
299 }
300 }
301 } else if is_map {
302 let comment = if !field_docs.is_empty() {
305 format!(" // {}", field_docs)
306 } else {
307 String::new()
308 };
309
310 if let Some(value_type) = map_value_type {
311 let is_value_primitive = is_primitive_type(value_type);
312
313 if !is_value_primitive {
314 field_schema_parts.push(quote! {
316 {
317 let type_name = stringify!(#value_type);
318 format!(" {}: Record<string, {}>;{}", #field_name_str, type_name, #comment)
319 }
320 });
321
322 nested_type_collectors.push(quote! {
324 <#value_type as #crate_path::prompt::ToPrompt>::prompt_schema()
325 });
326 } else {
327 let type_str = format_type_for_schema(value_type);
329 field_schema_parts.push(quote! {
330 format!(" {}: Record<string, {}>;{}", #field_name_str, #type_str, #comment)
331 });
332 }
333 }
334 } else if is_set {
335 let comment = if !field_docs.is_empty() {
338 format!(" // {}", field_docs)
339 } else {
340 String::new()
341 };
342
343 if let Some(element_type) = set_element_type {
344 let is_element_primitive = is_primitive_type(element_type);
345
346 if !is_element_primitive {
347 field_schema_parts.push(quote! {
349 {
350 let type_name = stringify!(#element_type);
351 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
352 }
353 });
354
355 nested_type_collectors.push(quote! {
357 <#element_type as #crate_path::prompt::ToPrompt>::prompt_schema()
358 });
359 } else {
360 let type_str = format_type_for_schema(element_type);
362 field_schema_parts.push(quote! {
363 format!(" {}: {}[];{}", #field_name_str, #type_str, #comment)
364 });
365 }
366 }
367 } else {
368 let field_type = &field.ty;
370 let is_primitive = is_primitive_type(field_type);
371
372 if !is_primitive {
373 let comment = if !field_docs.is_empty() {
376 format!(" // {}", field_docs)
377 } else {
378 String::new()
379 };
380
381 field_schema_parts.push(quote! {
382 {
383 let type_name = stringify!(#field_type);
384 format!(" {}: {};{}", #field_name_str, type_name, #comment)
385 }
386 });
387
388 nested_type_collectors.push(quote! {
390 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
391 });
392 } else {
393 let type_str = format_type_for_schema(&field.ty);
396 let comment = if !field_docs.is_empty() {
397 format!(" // {}", field_docs)
398 } else {
399 String::new()
400 };
401
402 field_schema_parts.push(quote! {
403 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
404 });
405 }
406 }
407 }
408
409 let mut header_lines = Vec::new();
424
425 if !struct_docs.is_empty() {
427 header_lines.push("/**".to_string());
428 header_lines.push(format!(" * {}", struct_docs));
429 header_lines.push(" */".to_string());
430 }
431
432 header_lines.push(format!("type {} = {{", struct_name));
434
435 quote! {
436 {
437 let mut all_lines: Vec<String> = Vec::new();
438
439 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
441 let mut seen_types = std::collections::HashSet::<String>::new();
442
443 for schema in nested_schemas {
444 if !schema.is_empty() {
445 if seen_types.insert(schema.clone()) {
447 all_lines.push(schema);
448 all_lines.push(String::new()); }
450 }
451 }
452
453 let mut lines: Vec<String> = Vec::new();
455 #(lines.push(#header_lines.to_string());)*
456 #(lines.push(#field_schema_parts);)*
457 lines.push("}".to_string());
458 all_lines.push(lines.join("\n"));
459
460 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
461 }
462 }
463}
464
465fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
467 if let syn::Type::Path(type_path) = ty
468 && let Some(last_segment) = type_path.path.segments.last()
469 && last_segment.ident == "Vec"
470 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
471 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
472 {
473 return (true, Some(inner_type));
474 }
475 (false, None)
476}
477
478fn extract_option_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
480 if let syn::Type::Path(type_path) = ty
481 && let Some(last_segment) = type_path.path.segments.last()
482 && last_segment.ident == "Option"
483 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
484 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
485 {
486 return (true, Some(inner_type));
487 }
488 (false, None)
489}
490
491fn extract_map_value_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
493 if let syn::Type::Path(type_path) = ty
494 && let Some(last_segment) = type_path.path.segments.last()
495 && (last_segment.ident == "HashMap" || last_segment.ident == "BTreeMap")
496 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
497 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
498 {
499 return (true, Some(value_type));
500 }
501 (false, None)
502}
503
504fn extract_set_element_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
506 if let syn::Type::Path(type_path) = ty
507 && let Some(last_segment) = type_path.path.segments.last()
508 && (last_segment.ident == "HashSet" || last_segment.ident == "BTreeSet")
509 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
510 && let Some(syn::GenericArgument::Type(element_type)) = args.args.first()
511 {
512 return (true, Some(element_type));
513 }
514 (false, None)
515}
516
517fn extract_expandable_type(ty: &syn::Type) -> &syn::Type {
520 if let syn::Type::Path(type_path) = ty
521 && let Some(last_segment) = type_path.path.segments.last()
522 {
523 let type_name = last_segment.ident.to_string();
524
525 if type_name == "Option"
527 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
528 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
529 {
530 return extract_expandable_type(inner_type);
531 }
532
533 if type_name == "Vec"
535 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
536 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
537 {
538 return extract_expandable_type(inner_type);
539 }
540
541 if (type_name == "HashSet" || type_name == "BTreeSet")
543 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
544 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
545 {
546 return extract_expandable_type(inner_type);
547 }
548
549 if (type_name == "HashMap" || type_name == "BTreeMap")
552 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
553 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
554 {
555 return extract_expandable_type(value_type);
556 }
557 }
558
559 ty
561}
562
563fn is_primitive_type(ty: &syn::Type) -> bool {
565 if let syn::Type::Path(type_path) = ty
566 && let Some(last_segment) = type_path.path.segments.last()
567 {
568 let type_name = last_segment.ident.to_string();
569
570 if type_name == "Option"
572 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
573 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
574 {
575 return is_primitive_type(inner_type);
576 }
577
578 if type_name == "Vec"
580 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
581 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
582 {
583 return is_primitive_type(inner_type);
584 }
585
586 if (type_name == "HashSet" || type_name == "BTreeSet")
588 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
589 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
590 {
591 return is_primitive_type(inner_type);
592 }
593
594 if (type_name == "HashMap" || type_name == "BTreeMap")
597 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
598 && let Some(syn::GenericArgument::Type(value_type)) = args.args.iter().nth(1)
599 {
600 return is_primitive_type(value_type);
601 }
602
603 matches!(
604 type_name.as_str(),
605 "String"
606 | "str"
607 | "i8"
608 | "i16"
609 | "i32"
610 | "i64"
611 | "i128"
612 | "isize"
613 | "u8"
614 | "u16"
615 | "u32"
616 | "u64"
617 | "u128"
618 | "usize"
619 | "f32"
620 | "f64"
621 | "bool"
622 )
623 } else {
624 true
626 }
627}
628
629fn format_type_for_schema(ty: &syn::Type) -> String {
631 match ty {
633 syn::Type::Path(type_path) => {
634 let path = &type_path.path;
635 if let Some(last_segment) = path.segments.last() {
636 let type_name = last_segment.ident.to_string();
637
638 if type_name == "Option"
640 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
641 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
642 {
643 return format!("{} | null", format_type_for_schema(inner_type));
644 }
645
646 match type_name.as_str() {
648 "String" | "str" => "string".to_string(),
649 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
650 | "u64" | "u128" | "usize" => "number".to_string(),
651 "f32" | "f64" => "number".to_string(),
652 "bool" => "boolean".to_string(),
653 "Vec" => {
654 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
655 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
656 {
657 return format!("{}[]", format_type_for_schema(inner_type));
658 }
659 "array".to_string()
660 }
661 _ => type_name,
663 }
664 } else {
665 "unknown".to_string()
666 }
667 }
668 _ => "unknown".to_string(),
669 }
670}
671
/// Settings parsed from `#[prompt(...)]` attributes on an enum variant.
#[derive(Default)]
struct PromptAttributes {
    /// Set by `#[prompt(skip)]`.
    skip: bool,
    /// Set by `#[prompt(rename = "...")]`.
    rename: Option<String>,
    /// Set by `#[prompt(description = "...")]` or the shorthand `#[prompt("...")]`.
    description: Option<String>,
}
679
680fn parse_prompt_attributes(attrs: &[syn::Attribute]) -> PromptAttributes {
683 let mut result = PromptAttributes::default();
684
685 for attr in attrs {
686 if attr.path().is_ident("prompt") {
687 if let Ok(meta_list) = attr.meta.require_list() {
689 if let Ok(metas) =
691 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
692 {
693 for meta in metas {
694 if let Meta::NameValue(nv) = meta {
695 if nv.path.is_ident("rename") {
696 if let syn::Expr::Lit(syn::ExprLit {
697 lit: syn::Lit::Str(lit_str),
698 ..
699 }) = nv.value
700 {
701 result.rename = Some(lit_str.value());
702 }
703 } else if nv.path.is_ident("description")
704 && let syn::Expr::Lit(syn::ExprLit {
705 lit: syn::Lit::Str(lit_str),
706 ..
707 }) = nv.value
708 {
709 result.description = Some(lit_str.value());
710 }
711 } else if let Meta::Path(path) = meta
712 && path.is_ident("skip")
713 {
714 result.skip = true;
715 }
716 }
717 }
718
719 let tokens_str = meta_list.tokens.to_string();
721 if tokens_str == "skip" {
722 result.skip = true;
723 }
724 }
725
726 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
728 result.description = Some(lit_str.value());
729 }
730 }
731 }
732 result
733}
734
/// Produces a placeholder example value, as JSON-ish text, for a rendered schema type.
///
/// `type_str` is expected to be output of `format_type_for_schema`:
/// - `string` → `"example"`, `number` → `0`, `boolean` → `false`;
/// - any `...[]` → `[]`;
/// - a union such as `T | null` uses the example for its first alternative;
/// - a `Record<...>` → `{}`;
/// - any other (custom) type → `"<See Name>"`.
fn generate_example_value_for_type(type_str: &str) -> String {
    match type_str {
        "string" => "\"example\"".to_string(),
        "number" => "0".to_string(),
        "boolean" => "false".to_string(),
        s if s.ends_with("[]") => "[]".to_string(),
        s if s.contains('|') => {
            // `split` always yields at least one piece; fall back to `s` defensively.
            let first_type = s.split('|').next().unwrap_or(s).trim();
            generate_example_value_for_type(first_type)
        }
        // Checked after the union split so e.g. `Record<string, string | null>` still
        // reaches here via its first piece `Record<string, string`.
        s if s.starts_with("Record<") => "{}".to_string(),
        custom_type => format!("\"<See {}>\"", custom_type),
    }
}
755
756fn parse_serde_variant_rename(attrs: &[syn::Attribute]) -> Option<String> {
758 for attr in attrs {
759 if attr.path().is_ident("serde")
760 && let Ok(meta_list) = attr.meta.require_list()
761 && let Ok(metas) =
762 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
763 {
764 for meta in metas {
765 if let Meta::NameValue(nv) = meta
766 && nv.path.is_ident("rename")
767 && let syn::Expr::Lit(syn::ExprLit {
768 lit: syn::Lit::Str(lit_str),
769 ..
770 }) = nv.value
771 {
772 return Some(lit_str.value());
773 }
774 }
775 }
776 }
777 None
778}
779
/// Casing rules accepted by `#[serde(rename_all = "...")]`, used to predict serialized
/// enum variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenameRule {
    /// Leaves the name untouched. Never returned by `from_str`.
    #[allow(dead_code)]
    None,
    LowerCase,
    UpperCase,
    PascalCase,
    CamelCase,
    SnakeCase,
    ScreamingSnakeCase,
    KebabCase,
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses the string value of `rename_all`; unrecognized rules yield `None`.
    // Intentionally an inherent method returning `Option` rather than `FromStr`.
    #[allow(clippy::should_implement_trait)]
    fn from_str(s: &str) -> Option<Self> {
        match s {
            "lowercase" => Some(Self::LowerCase),
            "UPPERCASE" => Some(Self::UpperCase),
            "PascalCase" => Some(Self::PascalCase),
            "camelCase" => Some(Self::CamelCase),
            "snake_case" => Some(Self::SnakeCase),
            "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
            "kebab-case" => Some(Self::KebabCase),
            "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
            _ => None,
        }
    }

    /// Applies the rule to a (PascalCase) variant name, mirroring serde_derive's variant
    /// renaming so the rendered name matches what serde actually serializes.
    ///
    /// Specifically:
    /// - case changes are ASCII-only;
    /// - snake case inserts `_` before every non-leading uppercase char;
    /// - kebab variants replace every `_` (including ones already in the name) with `-`.
    fn apply(&self, name: &str) -> String {
        match self {
            Self::None | Self::PascalCase => name.to_string(),
            Self::LowerCase => name.to_ascii_lowercase(),
            Self::UpperCase => name.to_ascii_uppercase(),
            Self::CamelCase => {
                let mut chars = name.chars();
                match chars.next() {
                    None => String::new(),
                    Some(first) => first.to_ascii_lowercase().to_string() + chars.as_str(),
                }
            }
            Self::SnakeCase => {
                let mut snake = String::with_capacity(name.len());
                for (i, ch) in name.char_indices() {
                    if i > 0 && ch.is_uppercase() {
                        snake.push('_');
                    }
                    snake.push(ch.to_ascii_lowercase());
                }
                snake
            }
            Self::ScreamingSnakeCase => Self::SnakeCase.apply(name).to_ascii_uppercase(),
            Self::KebabCase => Self::SnakeCase.apply(name).replace('_', "-"),
            Self::ScreamingKebabCase => Self::ScreamingSnakeCase.apply(name).replace('_', "-"),
        }
    }
}
873
874fn parse_serde_rename_all(attrs: &[syn::Attribute]) -> Option<RenameRule> {
876 for attr in attrs {
877 if attr.path().is_ident("serde")
878 && let Ok(meta_list) = attr.meta.require_list()
879 {
880 if let Ok(metas) =
882 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
883 {
884 for meta in metas {
885 if let Meta::NameValue(nv) = meta
886 && nv.path.is_ident("rename_all")
887 && let syn::Expr::Lit(syn::ExprLit {
888 lit: syn::Lit::Str(lit_str),
889 ..
890 }) = nv.value
891 {
892 return RenameRule::from_str(&lit_str.value());
893 }
894 }
895 }
896 }
897 }
898 None
899}
900
901fn parse_serde_tag(attrs: &[syn::Attribute]) -> Option<String> {
904 for attr in attrs {
905 if attr.path().is_ident("serde")
906 && let Ok(meta_list) = attr.meta.require_list()
907 {
908 if let Ok(metas) =
910 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
911 {
912 for meta in metas {
913 if let Meta::NameValue(nv) = meta
914 && nv.path.is_ident("tag")
915 && let syn::Expr::Lit(syn::ExprLit {
916 lit: syn::Lit::Str(lit_str),
917 ..
918 }) = nv.value
919 {
920 return Some(lit_str.value());
921 }
922 }
923 }
924 }
925 }
926 None
927}
928
929fn parse_serde_untagged(attrs: &[syn::Attribute]) -> bool {
932 for attr in attrs {
933 if attr.path().is_ident("serde")
934 && let Ok(meta_list) = attr.meta.require_list()
935 {
936 if let Ok(metas) =
938 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
939 {
940 for meta in metas {
941 if let Meta::Path(path) = meta
942 && path.is_ident("untagged")
943 {
944 return true;
945 }
946 }
947 }
948 }
949 }
950 false
951}
952
/// Settings parsed from `#[prompt(...)]` attributes on a struct field.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by `#[prompt(skip)]`; such fields are left out of schemas and examples.
    skip: bool,
    /// Set by `#[prompt(rename = "...")]`.
    rename: Option<String>,
    /// Set by `#[prompt(format_with = "...")]`.
    format_with: Option<String>,
    /// Set by `#[prompt(image)]`.
    image: bool,
    /// Set by `#[prompt(example = "...")]`; used verbatim as the example JSON string.
    example: Option<String>,
}
962
963fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
965 let mut result = FieldPromptAttrs::default();
966
967 for attr in attrs {
968 if attr.path().is_ident("prompt") {
969 if let Ok(meta_list) = attr.meta.require_list() {
971 if let Ok(metas) =
973 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
974 {
975 for meta in metas {
976 match meta {
977 Meta::Path(path) if path.is_ident("skip") => {
978 result.skip = true;
979 }
980 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
981 if let syn::Expr::Lit(syn::ExprLit {
982 lit: syn::Lit::Str(lit_str),
983 ..
984 }) = nv.value
985 {
986 result.rename = Some(lit_str.value());
987 }
988 }
989 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
990 if let syn::Expr::Lit(syn::ExprLit {
991 lit: syn::Lit::Str(lit_str),
992 ..
993 }) = nv.value
994 {
995 result.format_with = Some(lit_str.value());
996 }
997 }
998 Meta::Path(path) if path.is_ident("image") => {
999 result.image = true;
1000 }
1001 Meta::NameValue(nv) if nv.path.is_ident("example") => {
1002 if let syn::Expr::Lit(syn::ExprLit {
1003 lit: syn::Lit::Str(lit_str),
1004 ..
1005 }) = nv.value
1006 {
1007 result.example = Some(lit_str.value());
1008 }
1009 }
1010 _ => {}
1011 }
1012 }
1013 } else if meta_list.tokens.to_string() == "skip" {
1014 result.skip = true;
1016 } else if meta_list.tokens.to_string() == "image" {
1017 result.image = true;
1019 }
1020 }
1021 }
1022 }
1023
1024 result
1025}
1026
1027#[proc_macro_derive(ToPrompt, attributes(prompt))]
1070pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
1071 let input = parse_macro_input!(input as DeriveInput);
1072
1073 let found_crate =
1074 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1075 let crate_path = match found_crate {
1076 FoundCrate::Itself => {
1077 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1079 quote!(::#ident)
1080 }
1081 FoundCrate::Name(name) => {
1082 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1083 quote!(::#ident)
1084 }
1085 };
1086
1087 match &input.data {
1089 Data::Enum(data_enum) => {
1090 let enum_name = &input.ident;
1092 let enum_docs = extract_doc_comments(&input.attrs);
1093
1094 let serde_tag = parse_serde_tag(&input.attrs);
1096 let is_internally_tagged = serde_tag.is_some();
1097 let is_untagged = parse_serde_untagged(&input.attrs);
1098
1099 let rename_rule = parse_serde_rename_all(&input.attrs);
1101
1102 let mut variant_lines = Vec::new();
1115 let mut first_variant_name = None;
1116
1117 let mut example_unit: Option<String> = None;
1119 let mut example_struct: Option<String> = None;
1120 let mut example_tuple: Option<String> = None;
1121
1122 let mut nested_types: Vec<&syn::Type> = Vec::new();
1124
1125 for variant in &data_enum.variants {
1126 let variant_name = &variant.ident;
1127 let variant_name_str = variant_name.to_string();
1128
1129 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1131
1132 if prompt_attrs.skip {
1134 continue;
1135 }
1136
1137 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1143 prompt_rename.clone()
1144 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1145 serde_rename
1146 } else if let Some(rule) = rename_rule {
1147 rule.apply(&variant_name_str)
1148 } else {
1149 variant_name_str.clone()
1150 };
1151
1152 let variant_line = match &variant.fields {
1154 syn::Fields::Unit => {
1155 if example_unit.is_none() {
1157 example_unit = Some(format!("\"{}\"", variant_value));
1158 }
1159
1160 if let Some(desc) = &prompt_attrs.description {
1162 format!(" | \"{}\" // {}", variant_value, desc)
1163 } else {
1164 let docs = extract_doc_comments(&variant.attrs);
1165 if !docs.is_empty() {
1166 format!(" | \"{}\" // {}", variant_value, docs)
1167 } else {
1168 format!(" | \"{}\"", variant_value)
1169 }
1170 }
1171 }
1172 syn::Fields::Named(fields) => {
1173 let mut field_parts = Vec::new();
1174 let mut example_field_parts = Vec::new();
1175
1176 if is_internally_tagged && let Some(tag_name) = &serde_tag {
1178 field_parts.push(format!("{}: \"{}\"", tag_name, variant_value));
1179 example_field_parts
1180 .push(format!("{}: \"{}\"", tag_name, variant_value));
1181 }
1182
1183 for field in &fields.named {
1184 let field_name = field.ident.as_ref().unwrap().to_string();
1185 let field_type = format_type_for_schema(&field.ty);
1186 field_parts.push(format!("{}: {}", field_name, field_type.clone()));
1187
1188 let expandable_type = extract_expandable_type(&field.ty);
1191 if !is_primitive_type(expandable_type) {
1192 nested_types.push(expandable_type);
1193 }
1194
1195 let example_value = generate_example_value_for_type(&field_type);
1197 example_field_parts.push(format!("{}: {}", field_name, example_value));
1198 }
1199
1200 let field_str = field_parts.join(", ");
1201 let example_field_str = example_field_parts.join(", ");
1202
1203 if example_struct.is_none() {
1205 if is_untagged || is_internally_tagged {
1206 example_struct = Some(format!("{{ {} }}", example_field_str));
1207 } else {
1208 example_struct = Some(format!(
1209 "{{ \"{}\": {{ {} }} }}",
1210 variant_value, example_field_str
1211 ));
1212 }
1213 }
1214
1215 let comment = if let Some(desc) = &prompt_attrs.description {
1216 format!(" // {}", desc)
1217 } else {
1218 let docs = extract_doc_comments(&variant.attrs);
1219 if !docs.is_empty() {
1220 format!(" // {}", docs)
1221 } else if is_untagged {
1222 format!(" // {}", variant_value)
1224 } else {
1225 String::new()
1226 }
1227 };
1228
1229 if is_untagged {
1230 format!(" | {{ {} }}{}", field_str, comment)
1232 } else if is_internally_tagged {
1233 format!(" | {{ {} }}{}", field_str, comment)
1235 } else {
1236 format!(
1238 " | {{ \"{}\": {{ {} }} }}{}",
1239 variant_value, field_str, comment
1240 )
1241 }
1242 }
1243 syn::Fields::Unnamed(fields) => {
1244 let field_types: Vec<String> = fields
1245 .unnamed
1246 .iter()
1247 .map(|f| {
1248 let expandable_type = extract_expandable_type(&f.ty);
1251 if !is_primitive_type(expandable_type) {
1252 nested_types.push(expandable_type);
1253 }
1254 format_type_for_schema(&f.ty)
1255 })
1256 .collect();
1257
1258 let tuple_str = field_types.join(", ");
1259
1260 let example_values: Vec<String> = field_types
1262 .iter()
1263 .map(|type_str| generate_example_value_for_type(type_str))
1264 .collect();
1265 let example_tuple_str = example_values.join(", ");
1266
1267 if example_tuple.is_none() {
1269 if is_untagged || is_internally_tagged {
1270 example_tuple = Some(format!("[{}]", example_tuple_str));
1271 } else {
1272 example_tuple = Some(format!(
1273 "{{ \"{}\": [{}] }}",
1274 variant_value, example_tuple_str
1275 ));
1276 }
1277 }
1278
1279 let comment = if let Some(desc) = &prompt_attrs.description {
1280 format!(" // {}", desc)
1281 } else {
1282 let docs = extract_doc_comments(&variant.attrs);
1283 if !docs.is_empty() {
1284 format!(" // {}", docs)
1285 } else if is_untagged {
1286 format!(" // {}", variant_value)
1288 } else {
1289 String::new()
1290 }
1291 };
1292
1293 if is_untagged || is_internally_tagged {
1294 format!(" | [{}]{}", tuple_str, comment)
1297 } else {
1298 format!(
1300 " | {{ \"{}\": [{}] }}{}",
1301 variant_value, tuple_str, comment
1302 )
1303 }
1304 }
1305 };
1306
1307 variant_lines.push(variant_line);
1308
1309 if first_variant_name.is_none() {
1310 first_variant_name = Some(variant_value);
1311 }
1312 }
1313
1314 let mut lines = Vec::new();
1316
1317 if !enum_docs.is_empty() {
1319 lines.push("/**".to_string());
1320 lines.push(format!(" * {}", enum_docs));
1321 lines.push(" */".to_string());
1322 }
1323
1324 lines.push(format!("type {} =", enum_name));
1326
1327 for line in &variant_lines {
1329 lines.push(line.clone());
1330 }
1331
1332 if let Some(last) = lines.last_mut()
1334 && !last.ends_with(';')
1335 {
1336 last.push(';');
1337 }
1338
1339 let mut examples = Vec::new();
1341 if let Some(ex) = example_unit {
1342 examples.push(ex);
1343 }
1344 if let Some(ex) = example_struct {
1345 examples.push(ex);
1346 }
1347 if let Some(ex) = example_tuple {
1348 examples.push(ex);
1349 }
1350
1351 if !examples.is_empty() {
1352 lines.push("".to_string()); if examples.len() == 1 {
1354 lines.push(format!("Example value: {}", examples[0]));
1355 } else {
1356 lines.push("Example values:".to_string());
1357 for ex in examples {
1358 lines.push(format!(" {}", ex));
1359 }
1360 }
1361 }
1362
1363 let nested_type_tokens: Vec<_> = nested_types
1365 .iter()
1366 .map(|field_ty| {
1367 quote! {
1368 {
1369 let type_schema = <#field_ty as #crate_path::prompt::ToPrompt>::prompt_schema();
1370 if !type_schema.is_empty() {
1371 format!("\n\n{}", type_schema)
1372 } else {
1373 String::new()
1374 }
1375 }
1376 }
1377 })
1378 .collect();
1379
1380 let prompt_string = if nested_type_tokens.is_empty() {
1381 let lines_str = lines.join("\n");
1382 quote! { #lines_str.to_string() }
1383 } else {
1384 let lines_str = lines.join("\n");
1385 quote! {
1386 {
1387 let mut result = String::from(#lines_str);
1388
1389 let nested_schemas: Vec<String> = vec![#(#nested_type_tokens),*];
1391 let mut seen_schemas = std::collections::HashSet::<String>::new();
1392
1393 for schema in nested_schemas {
1394 if !schema.is_empty() && seen_schemas.insert(schema.clone()) {
1395 result.push_str(&schema);
1396 }
1397 }
1398
1399 result
1400 }
1401 }
1402 };
1403 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1404
1405 let mut match_arms = Vec::new();
1407 for variant in &data_enum.variants {
1408 let variant_name = &variant.ident;
1409 let variant_name_str = variant_name.to_string();
1410
1411 let prompt_attrs = parse_prompt_attributes(&variant.attrs);
1413
1414 let variant_value = if let Some(prompt_rename) = &prompt_attrs.rename {
1420 prompt_rename.clone()
1421 } else if let Some(serde_rename) = parse_serde_variant_rename(&variant.attrs) {
1422 serde_rename
1423 } else if let Some(rule) = rename_rule {
1424 rule.apply(&variant_name_str)
1425 } else {
1426 variant_name_str.clone()
1427 };
1428
1429 match &variant.fields {
1431 syn::Fields::Unit => {
1432 if prompt_attrs.skip {
1434 match_arms.push(quote! {
1435 Self::#variant_name => stringify!(#variant_name).to_string()
1436 });
1437 } else if let Some(desc) = &prompt_attrs.description {
1438 match_arms.push(quote! {
1439 Self::#variant_name => format!("{}: {}", #variant_value, #desc)
1440 });
1441 } else {
1442 let variant_docs = extract_doc_comments(&variant.attrs);
1443 if !variant_docs.is_empty() {
1444 match_arms.push(quote! {
1445 Self::#variant_name => format!("{}: {}", #variant_value, #variant_docs)
1446 });
1447 } else {
1448 match_arms.push(quote! {
1449 Self::#variant_name => #variant_value.to_string()
1450 });
1451 }
1452 }
1453 }
1454 syn::Fields::Named(fields) => {
1455 let field_bindings: Vec<_> = fields
1457 .named
1458 .iter()
1459 .map(|f| f.ident.as_ref().unwrap())
1460 .collect();
1461
1462 let field_displays: Vec<_> = fields
1463 .named
1464 .iter()
1465 .map(|f| {
1466 let field_name = f.ident.as_ref().unwrap();
1467 let field_name_str = field_name.to_string();
1468 quote! {
1469 format!("{}: {:?}", #field_name_str, #field_name)
1470 }
1471 })
1472 .collect();
1473
1474 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1475 desc.clone()
1476 } else {
1477 let docs = extract_doc_comments(&variant.attrs);
1478 if !docs.is_empty() {
1479 docs
1480 } else {
1481 String::new()
1482 }
1483 };
1484
1485 if doc_or_desc.is_empty() {
1486 match_arms.push(quote! {
1487 Self::#variant_name { #(#field_bindings),* } => {
1488 let fields = vec![#(#field_displays),*];
1489 format!("{} {{ {} }}", #variant_value, fields.join(", "))
1490 }
1491 });
1492 } else {
1493 match_arms.push(quote! {
1494 Self::#variant_name { #(#field_bindings),* } => {
1495 let fields = vec![#(#field_displays),*];
1496 format!("{}: {} {{ {} }}", #variant_value, #doc_or_desc, fields.join(", "))
1497 }
1498 });
1499 }
1500 }
1501 syn::Fields::Unnamed(fields) => {
1502 let field_count = fields.unnamed.len();
1504 let field_bindings: Vec<_> = (0..field_count)
1505 .map(|i| {
1506 syn::Ident::new(
1507 &format!("field{}", i),
1508 proc_macro2::Span::call_site(),
1509 )
1510 })
1511 .collect();
1512
1513 let field_displays: Vec<_> = field_bindings
1514 .iter()
1515 .map(|field_name| {
1516 quote! {
1517 format!("{:?}", #field_name)
1518 }
1519 })
1520 .collect();
1521
1522 let doc_or_desc = if let Some(desc) = &prompt_attrs.description {
1523 desc.clone()
1524 } else {
1525 let docs = extract_doc_comments(&variant.attrs);
1526 if !docs.is_empty() {
1527 docs
1528 } else {
1529 String::new()
1530 }
1531 };
1532
1533 if doc_or_desc.is_empty() {
1534 match_arms.push(quote! {
1535 Self::#variant_name(#(#field_bindings),*) => {
1536 let fields = vec![#(#field_displays),*];
1537 format!("{}({})", #variant_value, fields.join(", "))
1538 }
1539 });
1540 } else {
1541 match_arms.push(quote! {
1542 Self::#variant_name(#(#field_bindings),*) => {
1543 let fields = vec![#(#field_displays),*];
1544 format!("{}: {}({})", #variant_value, #doc_or_desc, fields.join(", "))
1545 }
1546 });
1547 }
1548 }
1549 }
1550 }
1551
1552 let to_prompt_impl = if match_arms.is_empty() {
1553 quote! {
1555 fn to_prompt(&self) -> String {
1556 match *self {}
1557 }
1558 }
1559 } else {
1560 quote! {
1561 fn to_prompt(&self) -> String {
1562 match self {
1563 #(#match_arms),*
1564 }
1565 }
1566 }
1567 };
1568
1569 let expanded = quote! {
1570 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
1571 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1572 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
1573 }
1574
1575 #to_prompt_impl
1576
1577 fn prompt_schema() -> String {
1578 #prompt_string
1579 }
1580 }
1581 };
1582
1583 TokenStream::from(expanded)
1584 }
1585 Data::Struct(data_struct) => {
1586 let mut template_attr = None;
1588 let mut template_file_attr = None;
1589 let mut mode_attr = None;
1590 let mut validate_attr = false;
1591 let mut type_marker_attr = false;
1592
1593 for attr in &input.attrs {
1594 if attr.path().is_ident("prompt") {
1595 if let Ok(metas) =
1597 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1598 {
1599 for meta in metas {
1600 match meta {
1601 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1602 if let syn::Expr::Lit(expr_lit) = nv.value
1603 && let syn::Lit::Str(lit_str) = expr_lit.lit
1604 {
1605 template_attr = Some(lit_str.value());
1606 }
1607 }
1608 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
1609 if let syn::Expr::Lit(expr_lit) = nv.value
1610 && let syn::Lit::Str(lit_str) = expr_lit.lit
1611 {
1612 template_file_attr = Some(lit_str.value());
1613 }
1614 }
1615 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1616 if let syn::Expr::Lit(expr_lit) = nv.value
1617 && let syn::Lit::Str(lit_str) = expr_lit.lit
1618 {
1619 mode_attr = Some(lit_str.value());
1620 }
1621 }
1622 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
1623 if let syn::Expr::Lit(expr_lit) = nv.value
1624 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1625 {
1626 validate_attr = lit_bool.value();
1627 }
1628 }
1629 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
1630 if let syn::Expr::Lit(expr_lit) = nv.value
1631 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
1632 {
1633 type_marker_attr = lit_bool.value();
1634 }
1635 }
1636 Meta::Path(path) if path.is_ident("type_marker") => {
1637 type_marker_attr = true;
1639 }
1640 _ => {}
1641 }
1642 }
1643 }
1644 }
1645 }
1646
1647 if template_attr.is_some() && template_file_attr.is_some() {
1649 return syn::Error::new(
1650 input.ident.span(),
1651 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
1652 ).to_compile_error().into();
1653 }
1654
1655 let template_str = if let Some(file_path) = template_file_attr {
1657 let mut full_path = None;
1661
1662 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1664 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
1666
1667 if !is_trybuild {
1668 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1670 if candidate.exists() {
1671 full_path = Some(candidate);
1672 }
1673 } else {
1674 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
1680 let workspace_root = &manifest_dir[..target_pos];
1681 let original_macros_dir = std::path::Path::new(workspace_root)
1683 .join("crates")
1684 .join("llm-toolkit-macros");
1685
1686 let candidate = original_macros_dir.join(&file_path);
1687 if candidate.exists() {
1688 full_path = Some(candidate);
1689 }
1690 }
1691 }
1692 }
1693
1694 if full_path.is_none() {
1696 let candidate = std::path::Path::new(&file_path).to_path_buf();
1697 if candidate.exists() {
1698 full_path = Some(candidate);
1699 }
1700 }
1701
1702 if full_path.is_none()
1705 && let Ok(current_dir) = std::env::current_dir()
1706 {
1707 let mut search_dir = current_dir.as_path();
1708 for _ in 0..10 {
1710 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
1712 if macros_dir.exists() {
1713 let candidate = macros_dir.join(&file_path);
1714 if candidate.exists() {
1715 full_path = Some(candidate);
1716 break;
1717 }
1718 }
1719 let candidate = search_dir.join(&file_path);
1721 if candidate.exists() {
1722 full_path = Some(candidate);
1723 break;
1724 }
1725 if let Some(parent) = search_dir.parent() {
1726 search_dir = parent;
1727 } else {
1728 break;
1729 }
1730 }
1731 }
1732
1733 if full_path.is_none() {
1735 let mut error_msg = format!(
1737 "Template file '{}' not found at compile time.\n\nSearched in:",
1738 file_path
1739 );
1740
1741 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
1742 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
1743 error_msg.push_str(&format!("\n - {}", candidate.display()));
1744 }
1745
1746 if let Ok(current_dir) = std::env::current_dir() {
1747 let candidate = current_dir.join(&file_path);
1748 error_msg.push_str(&format!("\n - {}", candidate.display()));
1749 }
1750
1751 error_msg.push_str("\n\nPlease ensure:");
1752 error_msg.push_str("\n 1. The template file exists");
1753 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
1754 error_msg.push_str("\n 3. There are no typos in the path");
1755
1756 return syn::Error::new(input.ident.span(), error_msg)
1757 .to_compile_error()
1758 .into();
1759 }
1760
1761 let final_path = full_path.unwrap();
1762
1763 match std::fs::read_to_string(&final_path) {
1765 Ok(content) => Some(content),
1766 Err(e) => {
1767 return syn::Error::new(
1768 input.ident.span(),
1769 format!(
1770 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
1771 file_path,
1772 e,
1773 final_path.display()
1774 ),
1775 )
1776 .to_compile_error()
1777 .into();
1778 }
1779 }
1780 } else {
1781 template_attr
1782 };
1783
1784 if validate_attr && let Some(template) = &template_str {
1786 let mut env = minijinja::Environment::new();
1788 if let Err(e) = env.add_template("validation", template) {
1789 let warning_msg =
1791 format!("Template validation warning: Invalid Jinja syntax - {}", e);
1792 let warning_ident = syn::Ident::new(
1793 "TEMPLATE_VALIDATION_WARNING",
1794 proc_macro2::Span::call_site(),
1795 );
1796 let _warning_tokens = quote! {
1797 #[deprecated(note = #warning_msg)]
1798 const #warning_ident: () = ();
1799 let _ = #warning_ident;
1800 };
1801 eprintln!("cargo:warning={}", warning_msg);
1803 }
1804
1805 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1807 &fields.named
1808 } else {
1809 panic!("Template validation is only supported for structs with named fields.");
1810 };
1811
1812 let field_names: std::collections::HashSet<String> = fields
1813 .iter()
1814 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
1815 .collect();
1816
1817 let placeholders = parse_template_placeholders_with_mode(template);
1819
1820 for (placeholder_name, _mode) in &placeholders {
1821 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
1822 let warning_msg = format!(
1823 "Template validation warning: Variable '{}' used in template but not found in struct fields",
1824 placeholder_name
1825 );
1826 eprintln!("cargo:warning={}", warning_msg);
1827 }
1828 }
1829 }
1830
1831 let name = input.ident;
1832 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1833
1834 let struct_docs = extract_doc_comments(&input.attrs);
1836
1837 let is_mode_based =
1839 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
1840
1841 let expanded = if is_mode_based || mode_attr.is_some() {
1842 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1844 &fields.named
1845 } else {
1846 panic!(
1847 "Mode-based prompt generation is only supported for structs with named fields."
1848 );
1849 };
1850
1851 let struct_name_str = name.to_string();
1852
1853 let has_default = input.attrs.iter().any(|attr| {
1855 if attr.path().is_ident("derive")
1856 && let Ok(meta_list) = attr.meta.require_list()
1857 {
1858 let tokens_str = meta_list.tokens.to_string();
1859 tokens_str.contains("Default")
1860 } else {
1861 false
1862 }
1863 });
1864
1865 let schema_parts = generate_schema_only_parts(
1876 &struct_name_str,
1877 &struct_docs,
1878 fields,
1879 &crate_path,
1880 type_marker_attr,
1881 );
1882
1883 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
1885
1886 quote! {
1887 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1888 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1889 match mode {
1890 "schema_only" => #schema_parts,
1891 "example_only" => #example_parts,
1892 "full" | _ => {
1893 let mut parts = Vec::new();
1895
1896 let schema_parts = #schema_parts;
1898 parts.extend(schema_parts);
1899
1900 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1902 parts.push(#crate_path::prompt::PromptPart::Text(
1903 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1904 ));
1905
1906 let example_parts = #example_parts;
1908 parts.extend(example_parts);
1909
1910 parts
1911 }
1912 }
1913 }
1914
1915 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1916 self.to_prompt_parts_with_mode("full")
1917 }
1918
1919 fn to_prompt(&self) -> String {
1920 self.to_prompt_parts()
1921 .into_iter()
1922 .filter_map(|part| match part {
1923 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1924 _ => None,
1925 })
1926 .collect::<Vec<_>>()
1927 .join("\n")
1928 }
1929
1930 fn prompt_schema() -> String {
1931 use std::sync::OnceLock;
1932 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1933
1934 SCHEMA_CACHE.get_or_init(|| {
1935 let schema_parts = #schema_parts;
1936 schema_parts
1937 .into_iter()
1938 .filter_map(|part| match part {
1939 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1940 _ => None,
1941 })
1942 .collect::<Vec<_>>()
1943 .join("\n")
1944 }).clone()
1945 }
1946 }
1947 }
1948 } else if let Some(template) = template_str {
1949 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1952 &fields.named
1953 } else {
1954 panic!(
1955 "Template prompt generation is only supported for structs with named fields."
1956 );
1957 };
1958
1959 let struct_name_str = name.to_string();
1961 let schema_parts = generate_schema_only_parts(
1962 &struct_name_str,
1963 &struct_docs,
1964 fields,
1965 &crate_path,
1966 type_marker_attr,
1967 );
1968
1969 let placeholders = parse_template_placeholders_with_mode(&template);
1971 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1973 mode.is_some()
1974 && fields
1975 .iter()
1976 .any(|f| f.ident.as_ref().unwrap() == field_name)
1977 });
1978
1979 let mut image_field_parts = Vec::new();
1980 for f in fields.iter() {
1981 let field_name = f.ident.as_ref().unwrap();
1982 let attrs = parse_field_prompt_attrs(&f.attrs);
1983
1984 if attrs.image {
1985 image_field_parts.push(quote! {
1987 parts.extend(self.#field_name.to_prompt_parts());
1988 });
1989 }
1990 }
1991
1992 if has_mode_syntax {
1994 let mut context_fields = Vec::new();
1996 let mut modified_template = template.clone();
1997
1998 for (field_name, mode_opt) in &placeholders {
2000 if let Some(mode) = mode_opt {
2001 let unique_key = format!("{}__{}", field_name, mode);
2003
2004 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
2006 let replacement = format!("{{{{ {} }}}}", unique_key);
2007 modified_template = modified_template.replace(&pattern, &replacement);
2008
2009 let field_ident =
2011 syn::Ident::new(field_name, proc_macro2::Span::call_site());
2012
2013 context_fields.push(quote! {
2015 context.insert(
2016 #unique_key.to_string(),
2017 #crate_path::minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
2018 );
2019 });
2020 }
2021 }
2022
2023 for field in fields.iter() {
2025 let field_name = field.ident.as_ref().unwrap();
2026 let field_name_str = field_name.to_string();
2027
2028 let has_mode_entry = placeholders
2030 .iter()
2031 .any(|(name, mode)| name == &field_name_str && mode.is_some());
2032
2033 if !has_mode_entry {
2034 match &field.ty {
2036 syn::Type::Path(type_path) => {
2037 if let Some(segment) = type_path.path.segments.last() {
2038 let type_name = segment.ident.to_string();
2039
2040 let is_primitive = matches!(
2042 type_name.as_str(),
2043 "String"
2044 | "str"
2045 | "i8"
2046 | "i16"
2047 | "i32"
2048 | "i64"
2049 | "i128"
2050 | "isize"
2051 | "u8"
2052 | "u16"
2053 | "u32"
2054 | "u64"
2055 | "u128"
2056 | "usize"
2057 | "f32"
2058 | "f64"
2059 | "bool"
2060 | "char"
2061 );
2062
2063 if is_primitive {
2064 context_fields.push(quote! {
2066 context.insert(
2067 #field_name_str.to_string(),
2068 #crate_path::minijinja::Value::from_serialize(&self.#field_name)
2069 );
2070 });
2071 } else if type_name == "Option" {
2072 let args = &segment.arguments;
2075 let is_option_vec =
2076 if let syn::PathArguments::AngleBracketed(
2077 angle_args,
2078 ) = args
2079 {
2080 if let Some(syn::GenericArgument::Type(
2081 syn::Type::Path(inner_path),
2082 )) = angle_args.args.first()
2083 {
2084 if let Some(inner_seg) =
2085 inner_path.path.segments.last()
2086 {
2087 inner_seg.ident == "Vec"
2088 } else {
2089 false
2090 }
2091 } else {
2092 false
2093 }
2094 } else {
2095 false
2096 };
2097
2098 if is_option_vec {
2099 context_fields.push(quote! {
2101 context.insert(
2102 #field_name_str.to_string(),
2103 match &self.#field_name {
2104 Some(vec) => {
2105 use #crate_path::prompt::ToPrompt;
2106 let prompt_items: Vec<String> = vec.iter()
2107 .map(|item| item.to_prompt())
2108 .collect();
2109 #crate_path::minijinja::Value::from_serialize(&Some(prompt_items))
2110 }
2111 None => #crate_path::minijinja::Value::from_serialize(&None::<Vec<String>>),
2112 }
2113 );
2114 });
2115 } else {
2116 context_fields.push(quote! {
2118 context.insert(
2119 #field_name_str.to_string(),
2120 match &self.#field_name {
2121 Some(inner) => {
2122 use #crate_path::prompt::ToPrompt;
2123 #crate_path::minijinja::Value::from(inner.to_prompt())
2124 }
2125 None => #crate_path::minijinja::Value::from_serialize(&None::<()>),
2126 }
2127 );
2128 });
2129 }
2130 } else if type_name == "Vec" {
2131 context_fields.push(quote! {
2134 context.insert(
2135 #field_name_str.to_string(),
2136 {
2137 use #crate_path::prompt::ToPrompt;
2138 let prompt_items: Vec<String> = self.#field_name.iter()
2141 .map(|item| item.to_prompt())
2142 .collect();
2143 #crate_path::minijinja::Value::from_serialize(&prompt_items)
2144 }
2145 );
2146 });
2147 } else {
2148 context_fields.push(quote! {
2150 context.insert(
2151 #field_name_str.to_string(),
2152 #crate_path::minijinja::Value::from(self.#field_name.to_prompt())
2153 );
2154 });
2155 }
2156 }
2157 }
2158 _ => {
2159 context_fields.push(quote! {
2161 context.insert(
2162 #field_name_str.to_string(),
2163 #crate_path::minijinja::Value::from(self.#field_name.to_prompt())
2164 );
2165 });
2166 }
2167 }
2168 }
2169 }
2170
2171 quote! {
2172 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2173 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2174 let mut parts = Vec::new();
2175
2176 #(#image_field_parts)*
2178
2179 let text = {
2181 let mut env = #crate_path::minijinja::Environment::new();
2182 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
2183 panic!("Failed to parse template: {}", e)
2184 });
2185
2186 let tmpl = env.get_template("prompt").unwrap();
2187
2188 let mut context = std::collections::HashMap::new();
2189 #(#context_fields)*
2190
2191 tmpl.render(context).unwrap_or_else(|e| {
2192 format!("Failed to render prompt: {}", e)
2193 })
2194 };
2195
2196 if !text.is_empty() {
2197 parts.push(#crate_path::prompt::PromptPart::Text(text));
2198 }
2199
2200 parts
2201 }
2202
2203 fn to_prompt(&self) -> String {
2204 let mut env = #crate_path::minijinja::Environment::new();
2206 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
2207 panic!("Failed to parse template: {}", e)
2208 });
2209
2210 let tmpl = env.get_template("prompt").unwrap();
2211
2212 let mut context = std::collections::HashMap::new();
2213 #(#context_fields)*
2214
2215 tmpl.render(context).unwrap_or_else(|e| {
2216 format!("Failed to render prompt: {}", e)
2217 })
2218 }
2219
2220 fn prompt_schema() -> String {
2221 use std::sync::OnceLock;
2222 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
2223
2224 SCHEMA_CACHE.get_or_init(|| {
2225 let schema_parts = #schema_parts;
2226 schema_parts
2227 .into_iter()
2228 .filter_map(|part| match part {
2229 #crate_path::prompt::PromptPart::Text(text) => Some(text),
2230 _ => None,
2231 })
2232 .collect::<Vec<_>>()
2233 .join("\n")
2234 }).clone()
2235 }
2236 }
2237 }
2238 } else {
2239 let mut simple_context_fields = Vec::new();
2244 for field in fields.iter() {
2245 let field_name = field.ident.as_ref().unwrap();
2246 let field_name_str = field_name.to_string();
2247
2248 match &field.ty {
2250 syn::Type::Path(type_path) => {
2251 if let Some(segment) = type_path.path.segments.last() {
2252 let type_name = segment.ident.to_string();
2253
2254 let is_primitive = matches!(
2255 type_name.as_str(),
2256 "String"
2257 | "str"
2258 | "i8"
2259 | "i16"
2260 | "i32"
2261 | "i64"
2262 | "i128"
2263 | "isize"
2264 | "u8"
2265 | "u16"
2266 | "u32"
2267 | "u64"
2268 | "u128"
2269 | "usize"
2270 | "f32"
2271 | "f64"
2272 | "bool"
2273 | "char"
2274 );
2275
2276 if is_primitive {
2277 simple_context_fields.push(quote! {
2278 context.insert(
2279 #field_name_str.to_string(),
2280 #crate_path::minijinja::Value::from_serialize(&self.#field_name)
2281 );
2282 });
2283 } else if type_name == "Option" {
2284 let args = &segment.arguments;
2285 let is_option_vec =
2286 if let syn::PathArguments::AngleBracketed(angle_args) =
2287 args
2288 {
2289 if let Some(syn::GenericArgument::Type(
2290 syn::Type::Path(inner_path),
2291 )) = angle_args.args.first()
2292 {
2293 if let Some(inner_seg) =
2294 inner_path.path.segments.last()
2295 {
2296 inner_seg.ident == "Vec"
2297 } else {
2298 false
2299 }
2300 } else {
2301 false
2302 }
2303 } else {
2304 false
2305 };
2306
2307 if is_option_vec {
2308 simple_context_fields.push(quote! {
2309 context.insert(
2310 #field_name_str.to_string(),
2311 match &self.#field_name {
2312 Some(vec) => {
2313 use #crate_path::prompt::ToPrompt;
2314 let prompt_items: Vec<String> = vec.iter()
2315 .map(|item| item.to_prompt())
2316 .collect();
2317 #crate_path::minijinja::Value::from_serialize(&Some(prompt_items))
2318 }
2319 None => #crate_path::minijinja::Value::from_serialize(&None::<Vec<String>>),
2320 }
2321 );
2322 });
2323 } else {
2324 simple_context_fields.push(quote! {
2325 context.insert(
2326 #field_name_str.to_string(),
2327 match &self.#field_name {
2328 Some(inner) => {
2329 use #crate_path::prompt::ToPrompt;
2330 #crate_path::minijinja::Value::from(inner.to_prompt())
2331 }
2332 None => #crate_path::minijinja::Value::from_serialize(&None::<()>),
2333 }
2334 );
2335 });
2336 }
2337 } else if type_name == "Vec" {
2338 simple_context_fields.push(quote! {
2339 context.insert(
2340 #field_name_str.to_string(),
2341 {
2342 use #crate_path::prompt::ToPrompt;
2343 let prompt_items: Vec<String> = self.#field_name.iter()
2344 .map(|item| item.to_prompt())
2345 .collect();
2346 #crate_path::minijinja::Value::from_serialize(&prompt_items)
2347 }
2348 );
2349 });
2350 } else {
2351 simple_context_fields.push(quote! {
2352 context.insert(
2353 #field_name_str.to_string(),
2354 #crate_path::minijinja::Value::from(self.#field_name.to_prompt())
2355 );
2356 });
2357 }
2358 }
2359 }
2360 _ => {
2361 simple_context_fields.push(quote! {
2362 context.insert(
2363 #field_name_str.to_string(),
2364 #crate_path::minijinja::Value::from(self.#field_name.to_prompt())
2365 );
2366 });
2367 }
2368 }
2369 }
2370
2371 quote! {
2372 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2373 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2374 let mut parts = Vec::new();
2375
2376 #(#image_field_parts)*
2378
2379 let text = {
2381 let mut env = #crate_path::minijinja::Environment::new();
2382 env.add_template("prompt", #template).unwrap_or_else(|e| {
2383 panic!("Failed to parse template: {}", e)
2384 });
2385
2386 let tmpl = env.get_template("prompt").unwrap();
2387
2388 let mut context = std::collections::HashMap::new();
2389 #(#simple_context_fields)*
2390
2391 tmpl.render(context).unwrap_or_else(|e| {
2392 format!("Failed to render prompt: {}", e)
2393 })
2394 };
2395
2396 if !text.is_empty() {
2397 parts.push(#crate_path::prompt::PromptPart::Text(text));
2398 }
2399
2400 parts
2401 }
2402
2403 fn to_prompt(&self) -> String {
2404 let mut env = #crate_path::minijinja::Environment::new();
2406 env.add_template("prompt", #template).unwrap_or_else(|e| {
2407 panic!("Failed to parse template: {}", e)
2408 });
2409
2410 let tmpl = env.get_template("prompt").unwrap();
2411
2412 let mut context = std::collections::HashMap::new();
2413 #(#simple_context_fields)*
2414
2415 tmpl.render(context).unwrap_or_else(|e| {
2416 format!("Failed to render prompt: {}", e)
2417 })
2418 }
2419
2420 fn prompt_schema() -> String {
2421 use std::sync::OnceLock;
2422 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
2423
2424 SCHEMA_CACHE.get_or_init(|| {
2425 let schema_parts = #schema_parts;
2426 schema_parts
2427 .into_iter()
2428 .filter_map(|part| match part {
2429 #crate_path::prompt::PromptPart::Text(text) => Some(text),
2430 _ => None,
2431 })
2432 .collect::<Vec<_>>()
2433 .join("\n")
2434 }).clone()
2435 }
2436 }
2437 }
2438 }
2439 } else {
2440 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
2443 &fields.named
2444 } else {
2445 panic!(
2446 "Default prompt generation is only supported for structs with named fields."
2447 );
2448 };
2449
2450 let mut text_field_parts = Vec::new();
2452 let mut image_field_parts = Vec::new();
2453
2454 for f in fields.iter() {
2455 let field_name = f.ident.as_ref().unwrap();
2456 let attrs = parse_field_prompt_attrs(&f.attrs);
2457
2458 if attrs.skip {
2460 continue;
2461 }
2462
2463 if attrs.image {
2464 image_field_parts.push(quote! {
2466 parts.extend(self.#field_name.to_prompt_parts());
2467 });
2468 } else {
2469 let key = if let Some(rename) = attrs.rename {
2475 rename
2476 } else {
2477 let doc_comment = extract_doc_comments(&f.attrs);
2478 if !doc_comment.is_empty() {
2479 doc_comment
2480 } else {
2481 field_name.to_string()
2482 }
2483 };
2484
2485 let value_expr = if let Some(format_with) = attrs.format_with {
2487 let func_path: syn::Path =
2489 syn::parse_str(&format_with).unwrap_or_else(|_| {
2490 panic!("Invalid function path: {}", format_with)
2491 });
2492 quote! { #func_path(&self.#field_name) }
2493 } else {
2494 quote! { self.#field_name.to_prompt() }
2495 };
2496
2497 text_field_parts.push(quote! {
2498 text_parts.push(format!("{}: {}", #key, #value_expr));
2499 });
2500 }
2501 }
2502
2503 let struct_name_str = name.to_string();
2505 let schema_parts = generate_schema_only_parts(
2506 &struct_name_str,
2507 &struct_docs,
2508 fields,
2509 &crate_path,
2510 false, );
2512
2513 quote! {
2515 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
2516 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
2517 let mut parts = Vec::new();
2518
2519 #(#image_field_parts)*
2521
2522 let mut text_parts = Vec::new();
2524 #(#text_field_parts)*
2525
2526 if !text_parts.is_empty() {
2527 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2528 }
2529
2530 parts
2531 }
2532
2533 fn to_prompt(&self) -> String {
2534 let mut text_parts = Vec::new();
2535 #(#text_field_parts)*
2536 text_parts.join("\n")
2537 }
2538
2539 fn prompt_schema() -> String {
2540 use std::sync::OnceLock;
2541 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
2542
2543 SCHEMA_CACHE.get_or_init(|| {
2544 let schema_parts = #schema_parts;
2545 schema_parts
2546 .into_iter()
2547 .filter_map(|part| match part {
2548 #crate_path::prompt::PromptPart::Text(text) => Some(text),
2549 _ => None,
2550 })
2551 .collect::<Vec<_>>()
2552 .join("\n")
2553 }).clone()
2554 }
2555 }
2556 }
2557 };
2558
2559 TokenStream::from(expanded)
2560 }
2561 Data::Union(_) => {
2562 panic!("`#[derive(ToPrompt)]` is not supported for unions");
2563 }
2564 }
2565}
2566
2567#[derive(Debug, Clone)]
2569struct TargetInfo {
2570 name: String,
2571 template: Option<String>,
2572 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
2573}
2574
/// Per-target settings for a single field, parsed from a field-level `#[prompt_for(...)]`.
#[derive(Clone, Debug, Default)]
struct FieldTargetConfig {
    /// Set by a `skip` entry, or by the bare `#[prompt_for(skip)]` form.
    skip: bool,
    /// Value of a `rename = "..."` entry.
    rename: Option<String>,
    /// Value of a `format_with = "..."` entry, kept as an unparsed string
    /// (presumably a function path — confirm at the use site).
    format_with: Option<String>,
    /// Set by an `image` entry.
    image: bool,
    /// Set when the config came from a list that carried a `name = "..."` entry.
    include_only: bool,
}
2584
2585fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
2587 let mut configs = Vec::new();
2588
2589 for attr in attrs {
2590 if attr.path().is_ident("prompt_for")
2591 && let Ok(meta_list) = attr.meta.require_list()
2592 {
2593 if meta_list.tokens.to_string() == "skip" {
2595 let config = FieldTargetConfig {
2597 skip: true,
2598 ..Default::default()
2599 };
2600 configs.push(("*".to_string(), config));
2601 } else if let Ok(metas) =
2602 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2603 {
2604 let mut target_name = None;
2605 let mut config = FieldTargetConfig::default();
2606
2607 for meta in metas {
2608 match meta {
2609 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2610 if let syn::Expr::Lit(syn::ExprLit {
2611 lit: syn::Lit::Str(lit_str),
2612 ..
2613 }) = nv.value
2614 {
2615 target_name = Some(lit_str.value());
2616 }
2617 }
2618 Meta::Path(path) if path.is_ident("skip") => {
2619 config.skip = true;
2620 }
2621 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
2622 if let syn::Expr::Lit(syn::ExprLit {
2623 lit: syn::Lit::Str(lit_str),
2624 ..
2625 }) = nv.value
2626 {
2627 config.rename = Some(lit_str.value());
2628 }
2629 }
2630 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
2631 if let syn::Expr::Lit(syn::ExprLit {
2632 lit: syn::Lit::Str(lit_str),
2633 ..
2634 }) = nv.value
2635 {
2636 config.format_with = Some(lit_str.value());
2637 }
2638 }
2639 Meta::Path(path) if path.is_ident("image") => {
2640 config.image = true;
2641 }
2642 _ => {}
2643 }
2644 }
2645
2646 if let Some(name) = target_name {
2647 config.include_only = true;
2648 configs.push((name, config));
2649 }
2650 }
2651 }
2652 }
2653
2654 configs
2655}
2656
2657fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
2659 let mut targets = Vec::new();
2660
2661 for attr in attrs {
2662 if attr.path().is_ident("prompt_for")
2663 && let Ok(meta_list) = attr.meta.require_list()
2664 && let Ok(metas) =
2665 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2666 {
2667 let mut target_name = None;
2668 let mut template = None;
2669
2670 for meta in metas {
2671 match meta {
2672 Meta::NameValue(nv) if nv.path.is_ident("name") => {
2673 if let syn::Expr::Lit(syn::ExprLit {
2674 lit: syn::Lit::Str(lit_str),
2675 ..
2676 }) = nv.value
2677 {
2678 target_name = Some(lit_str.value());
2679 }
2680 }
2681 Meta::NameValue(nv) if nv.path.is_ident("template") => {
2682 if let syn::Expr::Lit(syn::ExprLit {
2683 lit: syn::Lit::Str(lit_str),
2684 ..
2685 }) = nv.value
2686 {
2687 template = Some(lit_str.value());
2688 }
2689 }
2690 _ => {}
2691 }
2692 }
2693
2694 if let Some(name) = target_name {
2695 targets.push(TargetInfo {
2696 name,
2697 template,
2698 field_configs: std::collections::HashMap::new(),
2699 });
2700 }
2701 }
2702 }
2703
2704 targets
2705}
2706
2707#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
2708pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
2709 let input = parse_macro_input!(input as DeriveInput);
2710
2711 let found_crate =
2712 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2713 let crate_path = match found_crate {
2714 FoundCrate::Itself => {
2715 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2717 quote!(::#ident)
2718 }
2719 FoundCrate::Name(name) => {
2720 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2721 quote!(::#ident)
2722 }
2723 };
2724
2725 let data_struct = match &input.data {
2727 Data::Struct(data) => data,
2728 _ => {
2729 return syn::Error::new(
2730 input.ident.span(),
2731 "`#[derive(ToPromptSet)]` is only supported for structs",
2732 )
2733 .to_compile_error()
2734 .into();
2735 }
2736 };
2737
2738 let fields = match &data_struct.fields {
2739 syn::Fields::Named(fields) => &fields.named,
2740 _ => {
2741 return syn::Error::new(
2742 input.ident.span(),
2743 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
2744 )
2745 .to_compile_error()
2746 .into();
2747 }
2748 };
2749
2750 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
2752
2753 for field in fields.iter() {
2755 let field_name = field.ident.as_ref().unwrap().to_string();
2756 let field_configs = parse_prompt_for_attrs(&field.attrs);
2757
2758 for (target_name, config) in field_configs {
2759 if target_name == "*" {
2760 for target in &mut targets {
2762 target
2763 .field_configs
2764 .entry(field_name.clone())
2765 .or_insert_with(FieldTargetConfig::default)
2766 .skip = config.skip;
2767 }
2768 } else {
2769 let target_exists = targets.iter().any(|t| t.name == target_name);
2771 if !target_exists {
2772 targets.push(TargetInfo {
2774 name: target_name.clone(),
2775 template: None,
2776 field_configs: std::collections::HashMap::new(),
2777 });
2778 }
2779
2780 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
2781
2782 target.field_configs.insert(field_name.clone(), config);
2783 }
2784 }
2785 }
2786
2787 let mut match_arms = Vec::new();
2789
2790 for target in &targets {
2791 let target_name = &target.name;
2792
2793 if let Some(template_str) = &target.template {
2794 let mut image_parts = Vec::new();
2796
2797 for field in fields.iter() {
2798 let field_name = field.ident.as_ref().unwrap();
2799 let field_name_str = field_name.to_string();
2800
2801 if let Some(config) = target.field_configs.get(&field_name_str)
2802 && config.image
2803 {
2804 image_parts.push(quote! {
2805 parts.extend(self.#field_name.to_prompt_parts());
2806 });
2807 }
2808 }
2809
2810 match_arms.push(quote! {
2811 #target_name => {
2812 let mut parts = Vec::new();
2813
2814 #(#image_parts)*
2815
2816 let text = #crate_path::prompt::render_prompt(#template_str, self)
2817 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
2818 target: #target_name.to_string(),
2819 source: e,
2820 })?;
2821
2822 if !text.is_empty() {
2823 parts.push(#crate_path::prompt::PromptPart::Text(text));
2824 }
2825
2826 Ok(parts)
2827 }
2828 });
2829 } else {
2830 let mut text_field_parts = Vec::new();
2832 let mut image_field_parts = Vec::new();
2833
2834 for field in fields.iter() {
2835 let field_name = field.ident.as_ref().unwrap();
2836 let field_name_str = field_name.to_string();
2837
2838 let config = target.field_configs.get(&field_name_str);
2840
2841 if let Some(cfg) = config
2843 && cfg.skip
2844 {
2845 continue;
2846 }
2847
2848 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
2852 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
2853 .iter()
2854 .any(|(name, _)| name != "*");
2855
2856 if has_any_target_specific_config && !is_explicitly_for_this_target {
2857 continue;
2858 }
2859
2860 if let Some(cfg) = config {
2861 if cfg.image {
2862 image_field_parts.push(quote! {
2863 parts.extend(self.#field_name.to_prompt_parts());
2864 });
2865 } else {
2866 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
2867
2868 let value_expr = if let Some(format_with) = &cfg.format_with {
2869 match syn::parse_str::<syn::Path>(format_with) {
2871 Ok(func_path) => quote! { #func_path(&self.#field_name) },
2872 Err(_) => {
2873 let error_msg = format!(
2875 "Invalid function path in format_with: '{}'",
2876 format_with
2877 );
2878 quote! {
2879 compile_error!(#error_msg);
2880 String::new()
2881 }
2882 }
2883 }
2884 } else {
2885 quote! { self.#field_name.to_prompt() }
2886 };
2887
2888 text_field_parts.push(quote! {
2889 text_parts.push(format!("{}: {}", #key, #value_expr));
2890 });
2891 }
2892 } else {
2893 text_field_parts.push(quote! {
2895 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
2896 });
2897 }
2898 }
2899
2900 match_arms.push(quote! {
2901 #target_name => {
2902 let mut parts = Vec::new();
2903
2904 #(#image_field_parts)*
2905
2906 let mut text_parts = Vec::new();
2907 #(#text_field_parts)*
2908
2909 if !text_parts.is_empty() {
2910 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
2911 }
2912
2913 Ok(parts)
2914 }
2915 });
2916 }
2917 }
2918
2919 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
2921
2922 match_arms.push(quote! {
2924 _ => {
2925 let available = vec![#(#target_names.to_string()),*];
2926 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
2927 target: target.to_string(),
2928 available,
2929 })
2930 }
2931 });
2932
2933 let struct_name = &input.ident;
2934 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2935
2936 let expanded = quote! {
2937 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
2938 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
2939 match target {
2940 #(#match_arms)*
2941 }
2942 }
2943 }
2944 };
2945
2946 TokenStream::from(expanded)
2947}
2948
2949struct TypeList {
2951 types: Punctuated<syn::Type, Token![,]>,
2952}
2953
2954impl Parse for TypeList {
2955 fn parse(input: ParseStream) -> syn::Result<Self> {
2956 Ok(TypeList {
2957 types: Punctuated::parse_terminated(input)?,
2958 })
2959 }
2960}
2961
2962#[proc_macro]
2986pub fn examples_section(input: TokenStream) -> TokenStream {
2987 let input = parse_macro_input!(input as TypeList);
2988
2989 let found_crate =
2990 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2991 let _crate_path = match found_crate {
2992 FoundCrate::Itself => quote!(crate),
2993 FoundCrate::Name(name) => {
2994 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2995 quote!(::#ident)
2996 }
2997 };
2998
2999 let mut type_sections = Vec::new();
3001
3002 for ty in input.types.iter() {
3003 let type_name_str = quote!(#ty).to_string();
3005
3006 type_sections.push(quote! {
3008 {
3009 let type_name = #type_name_str;
3010 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
3011 format!("---\n#### `{}`\n{}", type_name, json_example)
3012 }
3013 });
3014 }
3015
3016 let expanded = quote! {
3018 {
3019 let mut sections = Vec::new();
3020 sections.push("---".to_string());
3021 sections.push("### Examples".to_string());
3022 sections.push("".to_string());
3023 sections.push("Here are examples of the data structures you should use.".to_string());
3024 sections.push("".to_string());
3025
3026 #(sections.push(#type_sections);)*
3027
3028 sections.push("---".to_string());
3029
3030 sections.join("\n")
3031 }
3032 };
3033
3034 TokenStream::from(expanded)
3035}
3036
3037fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
3039 for attr in attrs {
3040 if attr.path().is_ident("prompt_for")
3041 && let Ok(meta_list) = attr.meta.require_list()
3042 && let Ok(metas) =
3043 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3044 {
3045 let mut target_type = None;
3046 let mut template = None;
3047
3048 for meta in metas {
3049 match meta {
3050 Meta::NameValue(nv) if nv.path.is_ident("target") => {
3051 if let syn::Expr::Lit(syn::ExprLit {
3052 lit: syn::Lit::Str(lit_str),
3053 ..
3054 }) = nv.value
3055 {
3056 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
3058 }
3059 }
3060 Meta::NameValue(nv) if nv.path.is_ident("template") => {
3061 if let syn::Expr::Lit(syn::ExprLit {
3062 lit: syn::Lit::Str(lit_str),
3063 ..
3064 }) = nv.value
3065 {
3066 template = Some(lit_str.value());
3067 }
3068 }
3069 _ => {}
3070 }
3071 }
3072
3073 if let (Some(target), Some(tmpl)) = (target_type, template) {
3074 return (target, tmpl);
3075 }
3076 }
3077 }
3078
3079 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
3080}
3081
3082#[proc_macro_attribute]
3116pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
3117 let input = parse_macro_input!(item as DeriveInput);
3118
3119 let found_crate =
3120 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3121 let crate_path = match found_crate {
3122 FoundCrate::Itself => {
3123 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3125 quote!(::#ident)
3126 }
3127 FoundCrate::Name(name) => {
3128 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3129 quote!(::#ident)
3130 }
3131 };
3132
3133 let enum_data = match &input.data {
3135 Data::Enum(data) => data,
3136 _ => {
3137 return syn::Error::new(
3138 input.ident.span(),
3139 "`#[define_intent]` can only be applied to enums",
3140 )
3141 .to_compile_error()
3142 .into();
3143 }
3144 };
3145
3146 let mut prompt_template = None;
3148 let mut extractor_tag = None;
3149 let mut mode = None;
3150
3151 for attr in &input.attrs {
3152 if attr.path().is_ident("intent")
3153 && let Ok(metas) =
3154 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3155 {
3156 for meta in metas {
3157 match meta {
3158 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
3159 if let syn::Expr::Lit(syn::ExprLit {
3160 lit: syn::Lit::Str(lit_str),
3161 ..
3162 }) = nv.value
3163 {
3164 prompt_template = Some(lit_str.value());
3165 }
3166 }
3167 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
3168 if let syn::Expr::Lit(syn::ExprLit {
3169 lit: syn::Lit::Str(lit_str),
3170 ..
3171 }) = nv.value
3172 {
3173 extractor_tag = Some(lit_str.value());
3174 }
3175 }
3176 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
3177 if let syn::Expr::Lit(syn::ExprLit {
3178 lit: syn::Lit::Str(lit_str),
3179 ..
3180 }) = nv.value
3181 {
3182 mode = Some(lit_str.value());
3183 }
3184 }
3185 _ => {}
3186 }
3187 }
3188 }
3189 }
3190
3191 let mode = mode.unwrap_or_else(|| "single".to_string());
3193
3194 if mode != "single" && mode != "multi_tag" {
3196 return syn::Error::new(
3197 input.ident.span(),
3198 "`mode` must be either \"single\" or \"multi_tag\"",
3199 )
3200 .to_compile_error()
3201 .into();
3202 }
3203
3204 let prompt_template = match prompt_template {
3206 Some(p) => p,
3207 None => {
3208 return syn::Error::new(
3209 input.ident.span(),
3210 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
3211 )
3212 .to_compile_error()
3213 .into();
3214 }
3215 };
3216
3217 if mode == "multi_tag" {
3219 let enum_name = &input.ident;
3220 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
3221 return generate_multi_tag_output(
3222 &input,
3223 enum_name,
3224 enum_data,
3225 prompt_template,
3226 actions_doc,
3227 );
3228 }
3229
3230 let extractor_tag = match extractor_tag {
3232 Some(t) => t,
3233 None => {
3234 return syn::Error::new(
3235 input.ident.span(),
3236 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
3237 )
3238 .to_compile_error()
3239 .into();
3240 }
3241 };
3242
3243 let enum_name = &input.ident;
3245 let enum_docs = extract_doc_comments(&input.attrs);
3246
3247 let mut intents_doc_lines = Vec::new();
3248
3249 if !enum_docs.is_empty() {
3251 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
3252 } else {
3253 intents_doc_lines.push(format!("{}:", enum_name));
3254 }
3255 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
3257
3258 for variant in &enum_data.variants {
3260 let variant_name = &variant.ident;
3261 let variant_docs = extract_doc_comments(&variant.attrs);
3262
3263 if !variant_docs.is_empty() {
3264 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
3265 } else {
3266 intents_doc_lines.push(format!("- {}", variant_name));
3267 }
3268 }
3269
3270 let intents_doc_str = intents_doc_lines.join("\n");
3271
3272 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
3274 let user_variables: Vec<String> = placeholders
3275 .iter()
3276 .filter_map(|(name, _)| {
3277 if name != "intents_doc" {
3278 Some(name.clone())
3279 } else {
3280 None
3281 }
3282 })
3283 .collect();
3284
3285 let enum_name_str = enum_name.to_string();
3287 let snake_case_name = to_snake_case(&enum_name_str);
3288 let function_name = syn::Ident::new(
3289 &format!("build_{}_prompt", snake_case_name),
3290 proc_macro2::Span::call_site(),
3291 );
3292
3293 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3295 .iter()
3296 .map(|var| {
3297 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3298 quote! { #ident: &str }
3299 })
3300 .collect();
3301
3302 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3304 .iter()
3305 .map(|var| {
3306 let var_str = var.clone();
3307 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3308 quote! {
3309 __template_context.insert(#var_str.to_string(), #crate_path::minijinja::Value::from(#ident));
3310 }
3311 })
3312 .collect();
3313
3314 let converted_template = prompt_template.clone();
3316
3317 let extractor_name = syn::Ident::new(
3319 &format!("{}Extractor", enum_name),
3320 proc_macro2::Span::call_site(),
3321 );
3322
3323 let filtered_attrs: Vec<_> = input
3325 .attrs
3326 .iter()
3327 .filter(|attr| !attr.path().is_ident("intent"))
3328 .collect();
3329
3330 let vis = &input.vis;
3332 let generics = &input.generics;
3333 let variants = &enum_data.variants;
3334 let enum_output = quote! {
3335 #(#filtered_attrs)*
3336 #vis enum #enum_name #generics {
3337 #variants
3338 }
3339 };
3340
3341 let expanded = quote! {
3343 #enum_output
3345
3346 pub fn #function_name(#(#function_params),*) -> String {
3348 let mut env = #crate_path::minijinja::Environment::new();
3349 env.add_template("prompt", #converted_template)
3350 .expect("Failed to parse intent prompt template");
3351
3352 let tmpl = env.get_template("prompt").unwrap();
3353
3354 let mut __template_context = std::collections::HashMap::new();
3355
3356 __template_context.insert("intents_doc".to_string(), #crate_path::minijinja::Value::from(#intents_doc_str));
3358
3359 #(#context_insertions)*
3361
3362 tmpl.render(&__template_context)
3363 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3364 }
3365
3366 pub struct #extractor_name;
3368
3369 impl #extractor_name {
3370 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
3371 }
3372
3373 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
3374 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
3375 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
3377 }
3378 }
3379 };
3380
3381 TokenStream::from(expanded)
3382}
3383
/// Converts a `CamelCase` identifier to `snake_case` for generated function names
/// (`build_<snake>_prompt`).
///
/// An underscore is inserted only before an uppercase character that follows a
/// non-uppercase one, so runs of capitals stay together: `"MyIntent"` → `"my_intent"` and
/// `"HTTPServer"` → `"httpserver"`. That rule determines names users already call, so it
/// is intentionally left unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // Some characters lowercase to several chars (e.g. 'İ' -> "i\u{307}"); keep all
            // of them rather than only the first.
            result.extend(ch.to_lowercase());
            prev_upper = true;
        } else {
            result.push(ch);
            prev_upper = false;
        }
    }

    result
}
3404
/// Variant-level options parsed from `#[action(...)]` on `multi_tag` intent enums.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// XML tag name from `#[action(tag = "...")]`. Variants left as `None` are ignored by
    /// the docs, regex and parsing generators.
    tag: Option<String>,
}
3410
3411fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
3412 let mut result = ActionAttrs::default();
3413
3414 for attr in attrs {
3415 if attr.path().is_ident("action")
3416 && let Ok(meta_list) = attr.meta.require_list()
3417 && let Ok(metas) =
3418 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
3419 {
3420 for meta in metas {
3421 if let Meta::NameValue(nv) = meta
3422 && nv.path.is_ident("tag")
3423 && let syn::Expr::Lit(syn::ExprLit {
3424 lit: syn::Lit::Str(lit_str),
3425 ..
3426 }) = nv.value
3427 {
3428 result.tag = Some(lit_str.value());
3429 }
3430 }
3431 }
3432 }
3433
3434 result
3435}
3436
/// Field-level markers parsed from `#[action(...)]` on named fields of `multi_tag`
/// variants.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// Set by `#[action(attribute)]`: the field is read from the XML attribute of the
    /// same name.
    is_attribute: bool,
    /// Set by `#[action(inner_text)]`: the field receives the element's text content.
    /// Generators check `is_attribute` first, so `attribute` wins if both are present.
    is_inner_text: bool,
}
3443
3444fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
3445 let mut result = FieldActionAttrs::default();
3446
3447 for attr in attrs {
3448 if attr.path().is_ident("action")
3449 && let Ok(meta_list) = attr.meta.require_list()
3450 {
3451 let tokens_str = meta_list.tokens.to_string();
3452 if tokens_str == "attribute" {
3453 result.is_attribute = true;
3454 } else if tokens_str == "inner_text" {
3455 result.is_inner_text = true;
3456 }
3457 }
3458 }
3459
3460 result
3461}
3462
3463fn generate_multi_tag_actions_doc(
3465 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3466) -> String {
3467 let mut doc_lines = Vec::new();
3468
3469 for variant in variants {
3470 let action_attrs = parse_action_attrs(&variant.attrs);
3471
3472 if let Some(tag) = action_attrs.tag {
3473 let variant_docs = extract_doc_comments(&variant.attrs);
3474
3475 match &variant.fields {
3476 syn::Fields::Unit => {
3477 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
3479 }
3480 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
3481 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
3483 }
3484 syn::Fields::Named(fields) => {
3485 let mut attrs_str = Vec::new();
3487 let mut has_inner_text = false;
3488
3489 for field in &fields.named {
3490 let field_name = field.ident.as_ref().unwrap();
3491 let field_attrs = parse_field_action_attrs(&field.attrs);
3492
3493 if field_attrs.is_attribute {
3494 attrs_str.push(format!("{}=\"...\"", field_name));
3495 } else if field_attrs.is_inner_text {
3496 has_inner_text = true;
3497 }
3498 }
3499
3500 let attrs_part = if !attrs_str.is_empty() {
3501 format!(" {}", attrs_str.join(" "))
3502 } else {
3503 String::new()
3504 };
3505
3506 if has_inner_text {
3507 doc_lines.push(format!(
3508 "- `<{}{}>...</{}>`: {}",
3509 tag, attrs_part, tag, variant_docs
3510 ));
3511 } else if !attrs_str.is_empty() {
3512 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
3513 } else {
3514 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
3515 }
3516
3517 for field in &fields.named {
3519 let field_name = field.ident.as_ref().unwrap();
3520 let field_attrs = parse_field_action_attrs(&field.attrs);
3521 let field_docs = extract_doc_comments(&field.attrs);
3522
3523 if field_attrs.is_attribute {
3524 doc_lines
3525 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
3526 } else if field_attrs.is_inner_text {
3527 doc_lines
3528 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
3529 }
3530 }
3531 }
3532 _ => {
3533 }
3535 }
3536 }
3537 }
3538
3539 doc_lines.join("\n")
3540}
3541
3542fn generate_tags_regex(
3544 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3545) -> String {
3546 let mut tag_names = Vec::new();
3547
3548 for variant in variants {
3549 let action_attrs = parse_action_attrs(&variant.attrs);
3550 if let Some(tag) = action_attrs.tag {
3551 tag_names.push(tag);
3552 }
3553 }
3554
3555 if tag_names.is_empty() {
3556 return String::new();
3557 }
3558
3559 let tags_pattern = tag_names.join("|");
3560 format!(
3563 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
3564 tags_pattern, tags_pattern, tags_pattern
3565 )
3566}
3567
3568fn generate_multi_tag_output(
3570 input: &DeriveInput,
3571 enum_name: &syn::Ident,
3572 enum_data: &syn::DataEnum,
3573 prompt_template: String,
3574 actions_doc: String,
3575) -> TokenStream {
3576 let found_crate =
3577 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3578 let crate_path = match found_crate {
3579 FoundCrate::Itself => {
3580 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3582 quote!(::#ident)
3583 }
3584 FoundCrate::Name(name) => {
3585 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3586 quote!(::#ident)
3587 }
3588 };
3589
3590 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
3592 let user_variables: Vec<String> = placeholders
3593 .iter()
3594 .filter_map(|(name, _)| {
3595 if name != "actions_doc" {
3596 Some(name.clone())
3597 } else {
3598 None
3599 }
3600 })
3601 .collect();
3602
3603 let enum_name_str = enum_name.to_string();
3605 let snake_case_name = to_snake_case(&enum_name_str);
3606 let function_name = syn::Ident::new(
3607 &format!("build_{}_prompt", snake_case_name),
3608 proc_macro2::Span::call_site(),
3609 );
3610
3611 let function_params: Vec<proc_macro2::TokenStream> = user_variables
3613 .iter()
3614 .map(|var| {
3615 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3616 quote! { #ident: &str }
3617 })
3618 .collect();
3619
3620 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
3622 .iter()
3623 .map(|var| {
3624 let var_str = var.clone();
3625 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
3626 quote! {
3627 __template_context.insert(#var_str.to_string(), #crate_path::minijinja::Value::from(#ident));
3628 }
3629 })
3630 .collect();
3631
3632 let extractor_name = syn::Ident::new(
3634 &format!("{}Extractor", enum_name),
3635 proc_macro2::Span::call_site(),
3636 );
3637
3638 let filtered_attrs: Vec<_> = input
3640 .attrs
3641 .iter()
3642 .filter(|attr| !attr.path().is_ident("intent"))
3643 .collect();
3644
3645 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
3647 .variants
3648 .iter()
3649 .map(|variant| {
3650 let variant_name = &variant.ident;
3651 let variant_attrs: Vec<_> = variant
3652 .attrs
3653 .iter()
3654 .filter(|attr| !attr.path().is_ident("action"))
3655 .collect();
3656 let fields = &variant.fields;
3657
3658 let filtered_fields = match fields {
3660 syn::Fields::Named(named_fields) => {
3661 let filtered: Vec<_> = named_fields
3662 .named
3663 .iter()
3664 .map(|field| {
3665 let field_name = &field.ident;
3666 let field_type = &field.ty;
3667 let field_vis = &field.vis;
3668 let filtered_attrs: Vec<_> = field
3669 .attrs
3670 .iter()
3671 .filter(|attr| !attr.path().is_ident("action"))
3672 .collect();
3673 quote! {
3674 #(#filtered_attrs)*
3675 #field_vis #field_name: #field_type
3676 }
3677 })
3678 .collect();
3679 quote! { { #(#filtered,)* } }
3680 }
3681 syn::Fields::Unnamed(unnamed_fields) => {
3682 let types: Vec<_> = unnamed_fields
3683 .unnamed
3684 .iter()
3685 .map(|field| {
3686 let field_type = &field.ty;
3687 quote! { #field_type }
3688 })
3689 .collect();
3690 quote! { (#(#types),*) }
3691 }
3692 syn::Fields::Unit => quote! {},
3693 };
3694
3695 quote! {
3696 #(#variant_attrs)*
3697 #variant_name #filtered_fields
3698 }
3699 })
3700 .collect();
3701
3702 let vis = &input.vis;
3703 let generics = &input.generics;
3704
3705 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
3707
3708 let tags_regex = generate_tags_regex(&enum_data.variants);
3710
3711 let expanded = quote! {
3712 #(#filtered_attrs)*
3714 #vis enum #enum_name #generics {
3715 #(#filtered_variants),*
3716 }
3717
3718 pub fn #function_name(#(#function_params),*) -> String {
3720 let mut env = #crate_path::minijinja::Environment::new();
3721 env.add_template("prompt", #prompt_template)
3722 .expect("Failed to parse intent prompt template");
3723
3724 let tmpl = env.get_template("prompt").unwrap();
3725
3726 let mut __template_context = std::collections::HashMap::new();
3727
3728 __template_context.insert("actions_doc".to_string(), #crate_path::minijinja::Value::from(#actions_doc));
3730
3731 #(#context_insertions)*
3733
3734 tmpl.render(&__template_context)
3735 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
3736 }
3737
3738 pub struct #extractor_name;
3740
3741 impl #extractor_name {
3742 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
3743 use ::quick_xml::events::Event;
3744 use ::quick_xml::Reader;
3745
3746 let mut actions = Vec::new();
3747 let mut reader = Reader::from_str(text);
3748 reader.config_mut().trim_text(true);
3749
3750 let mut buf = Vec::new();
3751
3752 loop {
3753 match reader.read_event_into(&mut buf) {
3754 Ok(Event::Start(e)) => {
3755 let owned_e = e.into_owned();
3756 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3757 let is_empty = false;
3758
3759 #parsing_arms
3760 }
3761 Ok(Event::Empty(e)) => {
3762 let owned_e = e.into_owned();
3763 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3764 let is_empty = true;
3765
3766 #parsing_arms
3767 }
3768 Ok(Event::Eof) => break,
3769 Err(_) => {
3770 break;
3772 }
3773 _ => {}
3774 }
3775 buf.clear();
3776 }
3777
3778 actions.into_iter().next()
3779 }
3780
3781 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
3782 use ::quick_xml::events::Event;
3783 use ::quick_xml::Reader;
3784
3785 let mut actions = Vec::new();
3786 let mut reader = Reader::from_str(text);
3787 reader.config_mut().trim_text(true);
3788
3789 let mut buf = Vec::new();
3790
3791 loop {
3792 match reader.read_event_into(&mut buf) {
3793 Ok(Event::Start(e)) => {
3794 let owned_e = e.into_owned();
3795 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3796 let is_empty = false;
3797
3798 #parsing_arms
3799 }
3800 Ok(Event::Empty(e)) => {
3801 let owned_e = e.into_owned();
3802 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
3803 let is_empty = true;
3804
3805 #parsing_arms
3806 }
3807 Ok(Event::Eof) => break,
3808 Err(_) => {
3809 break;
3811 }
3812 _ => {}
3813 }
3814 buf.clear();
3815 }
3816
3817 Ok(actions)
3818 }
3819
3820 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
3821 where
3822 F: FnMut(#enum_name) -> String,
3823 {
3824 use ::regex::Regex;
3825
3826 let regex_pattern = #tags_regex;
3827 if regex_pattern.is_empty() {
3828 return text.to_string();
3829 }
3830
3831 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
3832 panic!("Failed to compile regex for action tags: {}", e);
3833 });
3834
3835 re.replace_all(text, |caps: &::regex::Captures| {
3836 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
3837
3838 if let Some(action) = self.parse_single_action(matched) {
3840 transformer(action)
3841 } else {
3842 matched.to_string()
3844 }
3845 }).to_string()
3846 }
3847
3848 pub fn strip_actions(&self, text: &str) -> String {
3849 self.transform_actions(text, |_| String::new())
3850 }
3851 }
3852 };
3853
3854 TokenStream::from(expanded)
3855}
3856
3857fn generate_parsing_arms(
3859 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
3860 enum_name: &syn::Ident,
3861) -> proc_macro2::TokenStream {
3862 let mut arms = Vec::new();
3863
3864 for variant in variants {
3865 let variant_name = &variant.ident;
3866 let action_attrs = parse_action_attrs(&variant.attrs);
3867
3868 if let Some(tag) = action_attrs.tag {
3869 match &variant.fields {
3870 syn::Fields::Unit => {
3871 arms.push(quote! {
3873 if &tag_name == #tag {
3874 actions.push(#enum_name::#variant_name);
3875 }
3876 });
3877 }
3878 syn::Fields::Unnamed(_fields) => {
3879 arms.push(quote! {
3881 if &tag_name == #tag && !is_empty {
3882 match reader.read_text(owned_e.name()) {
3884 Ok(text) => {
3885 actions.push(#enum_name::#variant_name(text.to_string()));
3886 }
3887 Err(_) => {
3888 actions.push(#enum_name::#variant_name(String::new()));
3890 }
3891 }
3892 }
3893 });
3894 }
3895 syn::Fields::Named(fields) => {
3896 let mut field_names = Vec::new();
3898 let mut has_inner_text_field = None;
3899
3900 for field in &fields.named {
3901 let field_name = field.ident.as_ref().unwrap();
3902 let field_attrs = parse_field_action_attrs(&field.attrs);
3903
3904 if field_attrs.is_attribute {
3905 field_names.push(field_name.clone());
3906 } else if field_attrs.is_inner_text {
3907 has_inner_text_field = Some(field_name.clone());
3908 }
3909 }
3910
3911 if let Some(inner_text_field) = has_inner_text_field {
3912 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3915 quote! {
3916 let mut #field_name = String::new();
3917 for attr in owned_e.attributes() {
3918 if let Ok(attr) = attr {
3919 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3920 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3921 break;
3922 }
3923 }
3924 }
3925 }
3926 }).collect();
3927
3928 arms.push(quote! {
3929 if &tag_name == #tag {
3930 #(#attr_extractions)*
3931
3932 if is_empty {
3934 let #inner_text_field = String::new();
3935 actions.push(#enum_name::#variant_name {
3936 #(#field_names,)*
3937 #inner_text_field,
3938 });
3939 } else {
3940 match reader.read_text(owned_e.name()) {
3942 Ok(text) => {
3943 let #inner_text_field = text.to_string();
3944 actions.push(#enum_name::#variant_name {
3945 #(#field_names,)*
3946 #inner_text_field,
3947 });
3948 }
3949 Err(_) => {
3950 let #inner_text_field = String::new();
3952 actions.push(#enum_name::#variant_name {
3953 #(#field_names,)*
3954 #inner_text_field,
3955 });
3956 }
3957 }
3958 }
3959 }
3960 });
3961 } else {
3962 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
3964 quote! {
3965 let mut #field_name = String::new();
3966 for attr in owned_e.attributes() {
3967 if let Ok(attr) = attr {
3968 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
3969 #field_name = String::from_utf8_lossy(&attr.value).to_string();
3970 break;
3971 }
3972 }
3973 }
3974 }
3975 }).collect();
3976
3977 arms.push(quote! {
3978 if &tag_name == #tag {
3979 #(#attr_extractions)*
3980 actions.push(#enum_name::#variant_name {
3981 #(#field_names),*
3982 });
3983 }
3984 });
3985 }
3986 }
3987 }
3988 }
3989 }
3990
3991 quote! {
3992 #(#arms)*
3993 }
3994}
3995
3996#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
3998pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
3999 let input = parse_macro_input!(input as DeriveInput);
4000
4001 let found_crate =
4002 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4003 let crate_path = match found_crate {
4004 FoundCrate::Itself => {
4005 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4007 quote!(::#ident)
4008 }
4009 FoundCrate::Name(name) => {
4010 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4011 quote!(::#ident)
4012 }
4013 };
4014
4015 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
4017
4018 let struct_name = &input.ident;
4019 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
4020
4021 let placeholders = parse_template_placeholders_with_mode(&template);
4023
4024 let mut converted_template = template.clone();
4026 let mut context_fields = Vec::new();
4027
4028 let fields = match &input.data {
4030 Data::Struct(data_struct) => match &data_struct.fields {
4031 syn::Fields::Named(fields) => &fields.named,
4032 _ => panic!("ToPromptFor is only supported for structs with named fields"),
4033 },
4034 _ => panic!("ToPromptFor is only supported for structs"),
4035 };
4036
4037 let has_mode_support = input.attrs.iter().any(|attr| {
4039 if attr.path().is_ident("prompt")
4040 && let Ok(metas) =
4041 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
4042 {
4043 for meta in metas {
4044 if let Meta::NameValue(nv) = meta
4045 && nv.path.is_ident("mode")
4046 {
4047 return true;
4048 }
4049 }
4050 }
4051 false
4052 });
4053
4054 for (placeholder_name, mode_opt) in &placeholders {
4056 if placeholder_name == "self" {
4057 if let Some(specific_mode) = mode_opt {
4058 let unique_key = format!("self__{}", specific_mode);
4060
4061 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
4063 let replacement = format!("{{{{ {} }}}}", unique_key);
4064 converted_template = converted_template.replace(&pattern, &replacement);
4065
4066 context_fields.push(quote! {
4068 context.insert(
4069 #unique_key.to_string(),
4070 #crate_path::minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
4071 );
4072 });
4073 } else {
4074 if has_mode_support {
4077 context_fields.push(quote! {
4079 context.insert(
4080 "self".to_string(),
4081 #crate_path::minijinja::Value::from(self.to_prompt_with_mode(mode))
4082 );
4083 });
4084 } else {
4085 context_fields.push(quote! {
4087 context.insert(
4088 "self".to_string(),
4089 #crate_path::minijinja::Value::from(self.to_prompt())
4090 );
4091 });
4092 }
4093 }
4094 } else {
4095 let field_exists = fields.iter().any(|f| {
4098 f.ident
4099 .as_ref()
4100 .is_some_and(|ident| ident == placeholder_name)
4101 });
4102
4103 if field_exists {
4104 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
4105
4106 context_fields.push(quote! {
4110 context.insert(
4111 #placeholder_name.to_string(),
4112 #crate_path::minijinja::Value::from_serialize(&self.#field_ident)
4113 );
4114 });
4115 }
4116 }
4118 }
4119
4120 let expanded = quote! {
4121 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
4122 where
4123 #target_type: serde::Serialize,
4124 {
4125 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
4126 let mut env = #crate_path::minijinja::Environment::new();
4128 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
4129 panic!("Failed to parse template: {}", e)
4130 });
4131
4132 let tmpl = env.get_template("prompt").unwrap();
4133
4134 let mut context = std::collections::HashMap::new();
4136 context.insert(
4138 "self".to_string(),
4139 #crate_path::minijinja::Value::from_serialize(self)
4140 );
4141 context.insert(
4143 "target".to_string(),
4144 #crate_path::minijinja::Value::from_serialize(target)
4145 );
4146 #(#context_fields)*
4147
4148 tmpl.render(context).unwrap_or_else(|e| {
4150 format!("Failed to render prompt: {}", e)
4151 })
4152 }
4153 }
4154 };
4155
4156 TokenStream::from(expanded)
4157}
4158
4159struct AgentAttrs {
4165 expertise: Option<String>,
4166 output: Option<syn::Type>,
4167 backend: Option<String>,
4168 model: Option<String>,
4169 inner: Option<String>,
4170 default_inner: Option<String>,
4171 max_retries: Option<u32>,
4172 profile: Option<String>,
4173 init: Option<String>,
4174 proxy_methods: Option<Vec<String>>,
4175 persona: Option<syn::Expr>,
4176}
4177
4178impl Parse for AgentAttrs {
4179 fn parse(input: ParseStream) -> syn::Result<Self> {
4180 let mut expertise = None;
4181 let mut output = None;
4182 let mut backend = None;
4183 let mut model = None;
4184 let mut inner = None;
4185 let mut default_inner = None;
4186 let mut max_retries = None;
4187 let mut profile = None;
4188 let mut init = None;
4189 let mut proxy_methods = None;
4190 let mut persona = None;
4191
4192 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
4193
4194 for meta in pairs {
4195 match meta {
4196 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
4197 if let syn::Expr::Lit(syn::ExprLit {
4198 lit: syn::Lit::Str(lit_str),
4199 ..
4200 }) = &nv.value
4201 {
4202 expertise = Some(lit_str.value());
4203 }
4204 }
4205 Meta::NameValue(nv) if nv.path.is_ident("output") => {
4206 if let syn::Expr::Lit(syn::ExprLit {
4207 lit: syn::Lit::Str(lit_str),
4208 ..
4209 }) = &nv.value
4210 {
4211 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
4212 output = Some(ty);
4213 }
4214 }
4215 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
4216 if let syn::Expr::Lit(syn::ExprLit {
4217 lit: syn::Lit::Str(lit_str),
4218 ..
4219 }) = &nv.value
4220 {
4221 backend = Some(lit_str.value());
4222 }
4223 }
4224 Meta::NameValue(nv) if nv.path.is_ident("model") => {
4225 if let syn::Expr::Lit(syn::ExprLit {
4226 lit: syn::Lit::Str(lit_str),
4227 ..
4228 }) = &nv.value
4229 {
4230 model = Some(lit_str.value());
4231 }
4232 }
4233 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
4234 if let syn::Expr::Lit(syn::ExprLit {
4235 lit: syn::Lit::Str(lit_str),
4236 ..
4237 }) = &nv.value
4238 {
4239 inner = Some(lit_str.value());
4240 }
4241 }
4242 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
4243 if let syn::Expr::Lit(syn::ExprLit {
4244 lit: syn::Lit::Str(lit_str),
4245 ..
4246 }) = &nv.value
4247 {
4248 default_inner = Some(lit_str.value());
4249 }
4250 }
4251 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
4252 if let syn::Expr::Lit(syn::ExprLit {
4253 lit: syn::Lit::Int(lit_int),
4254 ..
4255 }) = &nv.value
4256 {
4257 max_retries = Some(lit_int.base10_parse()?);
4258 }
4259 }
4260 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
4261 if let syn::Expr::Lit(syn::ExprLit {
4262 lit: syn::Lit::Str(lit_str),
4263 ..
4264 }) = &nv.value
4265 {
4266 profile = Some(lit_str.value());
4267 }
4268 }
4269 Meta::NameValue(nv) if nv.path.is_ident("init") => {
4270 if let syn::Expr::Lit(syn::ExprLit {
4271 lit: syn::Lit::Str(lit_str),
4272 ..
4273 }) = &nv.value
4274 {
4275 init = Some(lit_str.value());
4276 }
4277 }
4278 Meta::NameValue(nv) if nv.path.is_ident("proxy_methods") => {
4279 if let syn::Expr::Array(array) = &nv.value {
4280 let mut methods = Vec::new();
4281 for elem in &array.elems {
4282 if let syn::Expr::Lit(syn::ExprLit {
4283 lit: syn::Lit::Str(lit_str),
4284 ..
4285 }) = elem
4286 {
4287 methods.push(lit_str.value());
4288 }
4289 }
4290 proxy_methods = Some(methods);
4291 }
4292 }
4293 Meta::NameValue(nv) if nv.path.is_ident("persona") => {
4294 if let syn::Expr::Lit(syn::ExprLit {
4295 lit: syn::Lit::Str(lit_str),
4296 ..
4297 }) = &nv.value
4298 {
4299 let expr: syn::Expr = syn::parse_str(&lit_str.value())?;
4301 persona = Some(expr);
4302 }
4303 }
4304 _ => {}
4305 }
4306 }
4307
4308 Ok(AgentAttrs {
4309 expertise,
4310 output,
4311 backend,
4312 model,
4313 inner,
4314 default_inner,
4315 max_retries,
4316 profile,
4317 init,
4318 proxy_methods,
4319 persona,
4320 })
4321 }
4322}
4323
4324fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
4326 for attr in attrs {
4327 if attr.path().is_ident("agent") {
4328 return attr.parse_args::<AgentAttrs>();
4329 }
4330 }
4331
4332 Ok(AgentAttrs {
4333 expertise: None,
4334 output: None,
4335 backend: None,
4336 model: None,
4337 inner: None,
4338 default_inner: None,
4339 max_retries: None,
4340 profile: None,
4341 init: None,
4342 proxy_methods: None,
4343 persona: None,
4344 })
4345}
4346
4347fn generate_backend_constructors(
4349 struct_name: &syn::Ident,
4350 backend: &str,
4351 _model: Option<&str>,
4352 _profile: Option<&str>,
4353 crate_path: &proc_macro2::TokenStream,
4354) -> proc_macro2::TokenStream {
4355 match backend {
4356 "claude" => {
4357 quote! {
4358 impl #struct_name {
4359 pub fn with_claude() -> Self {
4361 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
4362 }
4363
4364 pub fn with_claude_model(model: &str) -> Self {
4366 Self::new(
4367 #crate_path::agent::impls::ClaudeCodeAgent::new()
4368 .with_model_str(model)
4369 )
4370 }
4371 }
4372 }
4373 }
4374 "gemini" => {
4375 quote! {
4376 impl #struct_name {
4377 pub fn with_gemini() -> Self {
4379 Self::new(#crate_path::agent::impls::GeminiAgent::new())
4380 }
4381
4382 pub fn with_gemini_model(model: &str) -> Self {
4384 Self::new(
4385 #crate_path::agent::impls::GeminiAgent::new()
4386 .with_model_str(model)
4387 )
4388 }
4389 }
4390 }
4391 }
4392 _ => quote! {},
4393 }
4394}
4395
4396fn generate_default_impl(
4398 struct_name: &syn::Ident,
4399 backend: &str,
4400 model: Option<&str>,
4401 profile: Option<&str>,
4402 crate_path: &proc_macro2::TokenStream,
4403) -> proc_macro2::TokenStream {
4404 let profile_expr = if let Some(profile_str) = profile {
4406 match profile_str.to_lowercase().as_str() {
4407 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4408 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4409 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
4410 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
4412 } else {
4413 quote! { #crate_path::agent::ExecutionProfile::default() }
4414 };
4415
4416 let agent_init = match backend {
4417 "gemini" => {
4418 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
4419
4420 if let Some(model_str) = model {
4421 builder = quote! { #builder.with_model_str(#model_str) };
4422 }
4423
4424 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4425 builder
4426 }
4427 _ => {
4428 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
4430
4431 if let Some(model_str) = model {
4432 builder = quote! { #builder.with_model_str(#model_str) };
4433 }
4434
4435 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4436 builder
4437 }
4438 };
4439
4440 quote! {
4441 impl Default for #struct_name {
4442 fn default() -> Self {
4443 Self::new(#agent_init)
4444 }
4445 }
4446 }
4447}
4448
4449#[proc_macro_derive(Agent, attributes(agent))]
4458pub fn derive_agent(input: TokenStream) -> TokenStream {
4459 let input = parse_macro_input!(input as DeriveInput);
4460 let struct_name = &input.ident;
4461
4462 let agent_attrs = match parse_agent_attrs(&input.attrs) {
4464 Ok(attrs) => attrs,
4465 Err(e) => return e.to_compile_error().into(),
4466 };
4467
4468 let expertise = agent_attrs
4469 .expertise
4470 .unwrap_or_else(|| String::from("general AI assistant"));
4471 let output_type = agent_attrs
4472 .output
4473 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4474 let backend = agent_attrs
4475 .backend
4476 .unwrap_or_else(|| String::from("claude"));
4477 let model = agent_attrs.model;
4478 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
4483 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4484 let crate_path = match found_crate {
4485 FoundCrate::Itself => {
4486 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4488 quote!(::#ident)
4489 }
4490 FoundCrate::Name(name) => {
4491 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4492 quote!(::#ident)
4493 }
4494 };
4495
4496 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
4497
4498 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4500 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4501
4502 let enhanced_expertise = if is_string_output {
4504 quote! { #expertise }
4506 } else {
4507 let type_name = quote!(#output_type).to_string();
4509 quote! {
4510 {
4511 use std::sync::OnceLock;
4512 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4513
4514 EXPERTISE_CACHE.get_or_init(|| {
4515 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4517
4518 if schema.is_empty() {
4519 format!(
4521 concat!(
4522 #expertise,
4523 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4524 "Do not include any text outside the JSON object."
4525 ),
4526 #type_name
4527 )
4528 } else {
4529 format!(
4531 concat!(
4532 #expertise,
4533 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4534 ),
4535 schema
4536 )
4537 }
4538 }).as_str()
4539 }
4540 }
4541 };
4542
4543 let agent_init = match backend.as_str() {
4545 "gemini" => {
4546 if let Some(model_str) = model {
4547 quote! {
4548 use #crate_path::agent::impls::GeminiAgent;
4549 let agent = GeminiAgent::new().with_model_str(#model_str);
4550 }
4551 } else {
4552 quote! {
4553 use #crate_path::agent::impls::GeminiAgent;
4554 let agent = GeminiAgent::new();
4555 }
4556 }
4557 }
4558 "claude" => {
4559 if let Some(model_str) = model {
4560 quote! {
4561 use #crate_path::agent::impls::ClaudeCodeAgent;
4562 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
4563 }
4564 } else {
4565 quote! {
4566 use #crate_path::agent::impls::ClaudeCodeAgent;
4567 let agent = ClaudeCodeAgent::new();
4568 }
4569 }
4570 }
4571 _ => {
4572 if let Some(model_str) = model {
4574 quote! {
4575 use #crate_path::agent::impls::ClaudeCodeAgent;
4576 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
4577 }
4578 } else {
4579 quote! {
4580 use #crate_path::agent::impls::ClaudeCodeAgent;
4581 let agent = ClaudeCodeAgent::new();
4582 }
4583 }
4584 }
4585 };
4586
4587 let expanded = quote! {
4588 #[async_trait::async_trait]
4589 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
4590 type Output = #output_type;
4591
4592 fn expertise(&self) -> &str {
4593 #enhanced_expertise
4594 }
4595
4596 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4597 #agent_init
4599
4600 let agent_ref = &agent;
4602 #crate_path::agent::retry::retry_execution(
4603 #max_retries,
4604 &intent,
4605 move |payload| {
4606 let payload = payload.clone();
4607 async move {
4608 let response = agent_ref.execute(payload).await?;
4610
4611 let json_str = #crate_path::extract_json(&response)
4613 .map_err(|e| #crate_path::agent::AgentError::ParseError {
4614 message: format!("Failed to extract JSON: {}", e),
4615 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
4616 })?;
4617
4618 serde_json::from_str::<Self::Output>(&json_str)
4620 .map_err(|e| {
4621 let reason = if e.is_eof() {
4623 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
4624 } else if e.is_syntax() {
4625 #crate_path::agent::error::ParseErrorReason::InvalidJson
4626 } else {
4627 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
4628 };
4629
4630 #crate_path::agent::AgentError::ParseError {
4631 message: format!("Failed to parse JSON: {}", e),
4632 reason,
4633 }
4634 })
4635 }
4636 }
4637 ).await
4638 }
4639
4640 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4641 #agent_init
4643 agent.is_available().await
4644 }
4645 }
4646 };
4647
4648 TokenStream::from(expanded)
4649}
4650
4651#[proc_macro_attribute]
4666pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
4667 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
4669 Ok(attrs) => attrs,
4670 Err(e) => return e.to_compile_error().into(),
4671 };
4672
4673 let input = parse_macro_input!(item as DeriveInput);
4675 let struct_name = &input.ident;
4676 let struct_name_str = struct_name.to_string();
4677 let vis = &input.vis;
4678
4679 let expertise = agent_attrs
4680 .expertise
4681 .unwrap_or_else(|| String::from("general AI assistant"));
4682 let output_type = agent_attrs
4683 .output
4684 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
4685 let backend = agent_attrs
4686 .backend
4687 .unwrap_or_else(|| String::from("claude"));
4688 let model = agent_attrs.model;
4689 let profile = agent_attrs.profile;
4690 let persona = agent_attrs.persona;
4691
4692 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
4694 let is_string_output = output_type_str == "String" || output_type_str == "&str";
4695
4696 let found_crate =
4698 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
4699 let crate_path = match found_crate {
4700 FoundCrate::Itself => {
4701 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
4702 quote!(::#ident)
4703 }
4704 FoundCrate::Name(name) => {
4705 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
4706 quote!(::#ident)
4707 }
4708 };
4709
4710 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
4712 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
4713
4714 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
4716 let type_path: syn::Type =
4718 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
4719 quote! { #type_path }
4720 } else {
4721 match backend.as_str() {
4723 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
4724 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
4725 }
4726 };
4727
4728 let (struct_def, _actual_inner_type, uses_persona) = if let Some(ref _persona_path) = persona {
4730 let wrapped_type =
4733 quote! { #crate_path::agent::persona::PersonaAgent<#inner_generic_ident> };
4734 let struct_def = quote! {
4735 #vis struct #struct_name<#inner_generic_ident: #crate_path::agent::Agent + Send + Sync = #default_agent_type> {
4736 inner: #wrapped_type,
4737 }
4738 };
4739 (struct_def, wrapped_type, true)
4740 } else {
4741 let struct_def = quote! {
4743 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
4744 inner: #inner_generic_ident,
4745 }
4746 };
4747 (struct_def, quote! { #inner_generic_ident }, false)
4748 };
4749
4750 let constructors = if let Some(ref persona_path) = persona {
4752 quote! {
4753 impl<#inner_generic_ident: #crate_path::agent::Agent + Send + Sync> #struct_name<#inner_generic_ident> {
4754 pub fn new(inner: #inner_generic_ident) -> Self {
4756 let persona_agent = #crate_path::agent::persona::PersonaAgent::new(
4757 inner,
4758 #persona_path.clone()
4759 );
4760 Self { inner: persona_agent }
4761 }
4762 }
4763 }
4764 } else {
4765 quote! {
4766 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
4767 pub fn new(inner: #inner_generic_ident) -> Self {
4769 Self { inner }
4770 }
4771 }
4772 }
4773 };
4774
4775 let (backend_constructors, default_impl) = if let Some(ref _persona_path) = persona {
4777 let agent_init = match backend.as_str() {
4779 "gemini" => {
4780 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
4781 if let Some(model_str) = model.as_deref() {
4782 builder = quote! { #builder.with_model_str(#model_str) };
4783 }
4784 if let Some(profile_str) = profile.as_deref() {
4785 let profile_expr = match profile_str.to_lowercase().as_str() {
4786 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4787 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4788 "deterministic" => {
4789 quote! { #crate_path::agent::ExecutionProfile::Deterministic }
4790 }
4791 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4792 };
4793 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4794 }
4795 builder
4796 }
4797 _ => {
4798 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
4799 if let Some(model_str) = model.as_deref() {
4800 builder = quote! { #builder.with_model_str(#model_str) };
4801 }
4802 if let Some(profile_str) = profile.as_deref() {
4803 let profile_expr = match profile_str.to_lowercase().as_str() {
4804 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
4805 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4806 "deterministic" => {
4807 quote! { #crate_path::agent::ExecutionProfile::Deterministic }
4808 }
4809 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced },
4810 };
4811 builder = quote! { #builder.with_execution_profile(#profile_expr) };
4812 }
4813 builder
4814 }
4815 };
4816
4817 let backend_constructors = match backend.as_str() {
4818 "claude" => {
4819 quote! {
4820 impl #struct_name {
4821 pub fn with_claude() -> Self {
4823 let base_agent = #crate_path::agent::impls::ClaudeCodeAgent::new();
4824 Self::new(base_agent)
4825 }
4826
4827 pub fn with_claude_model(model: &str) -> Self {
4829 let base_agent = #crate_path::agent::impls::ClaudeCodeAgent::new()
4830 .with_model_str(model);
4831 Self::new(base_agent)
4832 }
4833 }
4834 }
4835 }
4836 "gemini" => {
4837 quote! {
4838 impl #struct_name {
4839 pub fn with_gemini() -> Self {
4841 let base_agent = #crate_path::agent::impls::GeminiAgent::new();
4842 Self::new(base_agent)
4843 }
4844
4845 pub fn with_gemini_model(model: &str) -> Self {
4847 let base_agent = #crate_path::agent::impls::GeminiAgent::new()
4848 .with_model_str(model);
4849 Self::new(base_agent)
4850 }
4851 }
4852 }
4853 }
4854 _ => quote! {},
4855 };
4856
4857 let default_impl = quote! {
4858 impl Default for #struct_name {
4859 fn default() -> Self {
4860 let base_agent = #agent_init;
4861 Self::new(base_agent)
4862 }
4863 }
4864 };
4865
4866 (backend_constructors, default_impl)
4867 } else if agent_attrs.default_inner.is_some() {
4868 let default_impl = if let Some(init_fn) = &agent_attrs.init {
4870 let init_fn_ident: syn::Ident = syn::parse_str(init_fn).unwrap();
4872 quote! {
4873 impl Default for #struct_name {
4874 fn default() -> Self {
4875 let inner = <#default_agent_type as Default>::default();
4876 let inner = #init_fn_ident(inner);
4877 Self { inner }
4878 }
4879 }
4880 }
4881 } else {
4882 quote! {
4884 impl Default for #struct_name {
4885 fn default() -> Self {
4886 Self {
4887 inner: <#default_agent_type as Default>::default(),
4888 }
4889 }
4890 }
4891 }
4892 };
4893 (quote! {}, default_impl)
4894 } else {
4895 let backend_constructors = generate_backend_constructors(
4897 struct_name,
4898 &backend,
4899 model.as_deref(),
4900 profile.as_deref(),
4901 &crate_path,
4902 );
4903 let default_impl = generate_default_impl(
4904 struct_name,
4905 &backend,
4906 model.as_deref(),
4907 profile.as_deref(),
4908 &crate_path,
4909 );
4910 (backend_constructors, default_impl)
4911 };
4912
4913 let enhanced_expertise = if is_string_output {
4915 quote! { #expertise }
4917 } else {
4918 let type_name = quote!(#output_type).to_string();
4920 quote! {
4921 {
4922 use std::sync::OnceLock;
4923 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
4924
4925 EXPERTISE_CACHE.get_or_init(|| {
4926 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
4928
4929 if schema.is_empty() {
4930 format!(
4932 concat!(
4933 #expertise,
4934 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
4935 "Do not include any text outside the JSON object."
4936 ),
4937 #type_name
4938 )
4939 } else {
4940 format!(
4942 concat!(
4943 #expertise,
4944 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
4945 ),
4946 schema
4947 )
4948 }
4949 }).as_str()
4950 }
4951 }
4952 };
4953
4954 let agent_impl = if uses_persona {
4956 quote! {
4958 #[async_trait::async_trait]
4959 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4960 where
4961 #inner_generic_ident: #crate_path::agent::Agent + Send + Sync,
4962 <#inner_generic_ident as #crate_path::agent::Agent>::Output: Send,
4963 {
4964 type Output = <#inner_generic_ident as #crate_path::agent::Agent>::Output;
4965
4966 fn expertise(&self) -> &str {
4967 self.inner.expertise()
4968 }
4969
4970 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4971 self.inner.execute(intent).await
4972 }
4973
4974 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
4975 self.inner.is_available().await
4976 }
4977 }
4978 }
4979 } else {
4980 quote! {
4982 #[async_trait::async_trait]
4983 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
4984 where
4985 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
4986 {
4987 type Output = #output_type;
4988
4989 fn expertise(&self) -> &str {
4990 #enhanced_expertise
4991 }
4992
4993 #[#crate_path::tracing::instrument(name = "agent.execute", skip_all, fields(agent.name = #struct_name_str, agent.expertise = self.expertise()))]
4994 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
4995 let enhanced_payload = intent.prepend_text(self.expertise());
4997
4998 let response = self.inner.execute(enhanced_payload).await?;
5000
5001 let json_str = #crate_path::extract_json(&response)
5003 .map_err(|e| #crate_path::agent::AgentError::ParseError {
5004 message: e.to_string(),
5005 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
5006 })?;
5007
5008 serde_json::from_str(&json_str).map_err(|e| {
5010 let reason = if e.is_eof() {
5011 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
5012 } else if e.is_syntax() {
5013 #crate_path::agent::error::ParseErrorReason::InvalidJson
5014 } else {
5015 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
5016 };
5017 #crate_path::agent::AgentError::ParseError {
5018 message: e.to_string(),
5019 reason,
5020 }
5021 })
5022 }
5023
5024 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
5025 self.inner.is_available().await
5026 }
5027 }
5028 }
5029 };
5030
5031 let proxy_methods = if let Some(ref methods) = agent_attrs.proxy_methods {
5033 let method_impls: Vec<proc_macro2::TokenStream> = methods
5036 .iter()
5037 .filter_map(|method_name| {
5038 match method_name.as_str() {
5039 "with_cwd" | "with_directory" | "with_attachment_dir" => {
5040 let method_ident = syn::Ident::new(method_name, proc_macro2::Span::call_site());
5041 Some(quote! {
5042 pub fn #method_ident(self, path: impl Into<std::path::PathBuf>) -> Self {
5043 let inner = self.inner.#method_ident(path);
5044 Self::new(inner)
5045 }
5046 })
5047 }
5048 "with_env" => {
5049 Some(quote! {
5050 pub fn with_env(self, key: impl Into<String>, value: impl Into<String>) -> Self {
5051 let inner = self.inner.with_env(key, value);
5052 Self::new(inner)
5053 }
5054 })
5055 }
5056 "with_envs" => {
5057 Some(quote! {
5058 pub fn with_envs(self, envs: std::collections::HashMap<String, String>) -> Self {
5059 let inner = self.inner.with_envs(envs);
5060 Self::new(inner)
5061 }
5062 })
5063 }
5064 "with_arg" => {
5065 Some(quote! {
5066 pub fn with_arg(self, arg: impl Into<String>) -> Self {
5067 let inner = self.inner.with_arg(arg);
5068 Self::new(inner)
5069 }
5070 })
5071 }
5072 "with_args" => {
5073 Some(quote! {
5074 pub fn with_args(self, args: Vec<String>) -> Self {
5075 let inner = self.inner.with_args(args);
5076 Self::new(inner)
5077 }
5078 })
5079 }
5080 "with_model_str" => {
5081 Some(quote! {
5082 pub fn with_model_str(self, model: &str) -> Self {
5083 let inner = self.inner.with_model_str(model);
5084 Self::new(inner)
5085 }
5086 })
5087 }
5088 "with_execution_profile" => {
5089 Some(quote! {
5090 pub fn with_execution_profile(self, profile: #crate_path::agent::ExecutionProfile) -> Self {
5091 let inner = self.inner.with_execution_profile(profile);
5092 Self::new(inner)
5093 }
5094 })
5095 }
5096 "with_keep_attachments" => {
5097 Some(quote! {
5098 pub fn with_keep_attachments(self, keep: bool) -> Self {
5099 let inner = self.inner.with_keep_attachments(keep);
5100 Self::new(inner)
5101 }
5102 })
5103 }
5104 _ => None, }
5106 })
5107 .collect();
5108
5109 if !method_impls.is_empty() {
5110 quote! {
5111 impl #struct_name {
5112 #(#method_impls)*
5113 }
5114 }
5115 } else {
5116 quote! {}
5117 }
5118 } else {
5119 quote! {}
5120 };
5121
5122 let expanded = quote! {
5123 #struct_def
5124 #constructors
5125 #backend_constructors
5126 #default_impl
5127 #proxy_methods
5128 #agent_impl
5129 };
5130
5131 TokenStream::from(expanded)
5132}
5133
5134#[proc_macro_derive(TypeMarker)]
5156pub fn derive_type_marker(input: TokenStream) -> TokenStream {
5157 let input = parse_macro_input!(input as DeriveInput);
5158 let struct_name = &input.ident;
5159 let type_name_str = struct_name.to_string();
5160
5161 let found_crate =
5163 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
5164 let crate_path = match found_crate {
5165 FoundCrate::Itself => {
5166 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
5167 quote!(::#ident)
5168 }
5169 FoundCrate::Name(name) => {
5170 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
5171 quote!(::#ident)
5172 }
5173 };
5174
5175 let expanded = quote! {
5176 impl #crate_path::orchestrator::TypeMarker for #struct_name {
5177 const TYPE_NAME: &'static str = #type_name_str;
5178 }
5179 };
5180
5181 TokenStream::from(expanded)
5182}
5183
5184#[proc_macro_attribute]
5220pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
5221 let input = parse_macro_input!(item as syn::DeriveInput);
5222 let struct_name = &input.ident;
5223 let vis = &input.vis;
5224 let type_name_str = struct_name.to_string();
5225
5226 let default_fn_name = syn::Ident::new(
5228 &format!("default_{}_type", to_snake_case(&type_name_str)),
5229 struct_name.span(),
5230 );
5231
5232 let found_crate =
5234 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
5235 let crate_path = match found_crate {
5236 FoundCrate::Itself => {
5237 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
5238 quote!(::#ident)
5239 }
5240 FoundCrate::Name(name) => {
5241 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
5242 quote!(::#ident)
5243 }
5244 };
5245
5246 let fields = match &input.data {
5248 syn::Data::Struct(data_struct) => match &data_struct.fields {
5249 syn::Fields::Named(fields) => &fields.named,
5250 _ => {
5251 return syn::Error::new_spanned(
5252 struct_name,
5253 "type_marker only works with structs with named fields",
5254 )
5255 .to_compile_error()
5256 .into();
5257 }
5258 },
5259 _ => {
5260 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
5261 .to_compile_error()
5262 .into();
5263 }
5264 };
5265
5266 let mut new_fields = vec![];
5268
5269 let default_fn_name_str = default_fn_name.to_string();
5271 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
5272
5273 new_fields.push(quote! {
5278 #[serde(default = #default_fn_name_lit)]
5279 __type: String
5280 });
5281
5282 for field in fields {
5284 new_fields.push(quote! { #field });
5285 }
5286
5287 let attrs = &input.attrs;
5289 let generics = &input.generics;
5290
5291 let expanded = quote! {
5292 fn #default_fn_name() -> String {
5294 #type_name_str.to_string()
5295 }
5296
5297 #(#attrs)*
5299 #vis struct #struct_name #generics {
5300 #(#new_fields),*
5301 }
5302
5303 impl #crate_path::orchestrator::TypeMarker for #struct_name {
5305 const TYPE_NAME: &'static str = #type_name_str;
5306 }
5307 };
5308
5309 TokenStream::from(expanded)
5310}