1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use regex::Regex;
5use syn::{
6 Data, DeriveInput, Meta, Token,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10};
11
12fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
15 let mut placeholders = Vec::new();
16 let mut seen_fields = std::collections::HashSet::new();
17
18 let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
20 for cap in mode_pattern.captures_iter(template) {
21 let field_name = cap[1].to_string();
22 let mode = cap[2].to_string();
23 placeholders.push((field_name.clone(), Some(mode)));
24 seen_fields.insert(field_name);
25 }
26
27 let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
29 for cap in standard_pattern.captures_iter(template) {
30 let field_name = cap[1].to_string();
31 if !seen_fields.contains(&field_name) {
33 placeholders.push((field_name, None));
34 }
35 }
36
37 placeholders
38}
39
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
42 attrs
43 .iter()
44 .filter_map(|attr| {
45 if attr.path().is_ident("doc")
46 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
47 && let syn::Expr::Lit(syn::ExprLit {
48 lit: syn::Lit::Str(lit_str),
49 ..
50 }) = &meta_name_value.value
51 {
52 return Some(lit_str.value());
53 }
54 None
55 })
56 .map(|s| s.trim().to_string())
57 .collect::<Vec<_>>()
58 .join(" ")
59}
60
61fn generate_example_only_parts(
63 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
64 has_default: bool,
65 crate_path: &proc_macro2::TokenStream,
66) -> proc_macro2::TokenStream {
67 let mut field_values = Vec::new();
68
69 for field in fields.iter() {
70 let field_name = field.ident.as_ref().unwrap();
71 let field_name_str = field_name.to_string();
72 let attrs = parse_field_prompt_attrs(&field.attrs);
73
74 if field_name_str == "__type" {
78 continue;
79 }
80
81 if attrs.skip {
83 continue;
84 }
85
86 if let Some(example) = attrs.example {
88 field_values.push(quote! {
90 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
91 });
92 } else if has_default {
93 field_values.push(quote! {
95 let default_value = serde_json::to_value(&default_instance.#field_name)
96 .unwrap_or(serde_json::Value::Null);
97 json_obj.insert(#field_name_str.to_string(), default_value);
98 });
99 } else {
100 field_values.push(quote! {
102 let value = serde_json::to_value(&self.#field_name)
103 .unwrap_or(serde_json::Value::Null);
104 json_obj.insert(#field_name_str.to_string(), value);
105 });
106 }
107 }
108
109 if has_default {
110 quote! {
111 {
112 let default_instance = Self::default();
113 let mut json_obj = serde_json::Map::new();
114 #(#field_values)*
115 let json_value = serde_json::Value::Object(json_obj);
116 let json_str = serde_json::to_string_pretty(&json_value)
117 .unwrap_or_else(|_| "{}".to_string());
118 vec![#crate_path::prompt::PromptPart::Text(json_str)]
119 }
120 }
121 } else {
122 quote! {
123 {
124 let mut json_obj = serde_json::Map::new();
125 #(#field_values)*
126 let json_value = serde_json::Value::Object(json_obj);
127 let json_str = serde_json::to_string_pretty(&json_value)
128 .unwrap_or_else(|_| "{}".to_string());
129 vec![#crate_path::prompt::PromptPart::Text(json_str)]
130 }
131 }
132 }
133}
134
135fn generate_schema_only_parts(
137 struct_name: &str,
138 struct_docs: &str,
139 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
140 crate_path: &proc_macro2::TokenStream,
141 _has_type_marker: bool,
142) -> proc_macro2::TokenStream {
143 let mut field_schema_parts = vec![];
144 let mut nested_type_collectors = vec![];
145
146 for field in fields.iter() {
148 let field_name = field.ident.as_ref().unwrap();
149 let field_name_str = field_name.to_string();
150 let attrs = parse_field_prompt_attrs(&field.attrs);
151
152 if field_name_str == "__type" {
156 continue;
157 }
158
159 if attrs.skip {
161 continue;
162 }
163
164 let field_docs = extract_doc_comments(&field.attrs);
166
167 let (is_vec, inner_type) = extract_vec_inner_type(&field.ty);
169
170 if is_vec {
171 let comment = if !field_docs.is_empty() {
174 format!(" // {}", field_docs)
175 } else {
176 String::new()
177 };
178
179 field_schema_parts.push(quote! {
180 {
181 let type_name = stringify!(#inner_type);
182 format!(" {}: {}[];{}", #field_name_str, type_name, #comment)
183 }
184 });
185
186 if let Some(inner) = inner_type
188 && !is_primitive_type(inner)
189 {
190 nested_type_collectors.push(quote! {
191 <#inner as #crate_path::prompt::ToPrompt>::prompt_schema()
192 });
193 }
194 } else {
195 let field_type = &field.ty;
197 let is_primitive = is_primitive_type(field_type);
198
199 if !is_primitive {
200 let comment = if !field_docs.is_empty() {
203 format!(" // {}", field_docs)
204 } else {
205 String::new()
206 };
207
208 field_schema_parts.push(quote! {
209 {
210 let type_name = stringify!(#field_type);
211 format!(" {}: {};{}", #field_name_str, type_name, #comment)
212 }
213 });
214
215 nested_type_collectors.push(quote! {
217 <#field_type as #crate_path::prompt::ToPrompt>::prompt_schema()
218 });
219 } else {
220 let type_str = format_type_for_schema(&field.ty);
223 let comment = if !field_docs.is_empty() {
224 format!(" // {}", field_docs)
225 } else {
226 String::new()
227 };
228
229 field_schema_parts.push(quote! {
230 format!(" {}: {};{}", #field_name_str, #type_str, #comment)
231 });
232 }
233 }
234 }
235
236 let mut header_lines = Vec::new();
251
252 if !struct_docs.is_empty() {
254 header_lines.push("/**".to_string());
255 header_lines.push(format!(" * {}", struct_docs));
256 header_lines.push(" */".to_string());
257 }
258
259 header_lines.push(format!("type {} = {{", struct_name));
261
262 quote! {
263 {
264 let mut all_lines: Vec<String> = Vec::new();
265
266 let nested_schemas: Vec<String> = vec![#(#nested_type_collectors),*];
268 let mut seen_types = std::collections::HashSet::<String>::new();
269
270 for schema in nested_schemas {
271 if !schema.is_empty() {
272 if seen_types.insert(schema.clone()) {
274 all_lines.push(schema);
275 all_lines.push(String::new()); }
277 }
278 }
279
280 let mut lines: Vec<String> = Vec::new();
282 #(lines.push(#header_lines.to_string());)*
283 #(lines.push(#field_schema_parts);)*
284 lines.push("}".to_string());
285 all_lines.push(lines.join("\n"));
286
287 vec![#crate_path::prompt::PromptPart::Text(all_lines.join("\n"))]
288 }
289 }
290}
291
292fn extract_vec_inner_type(ty: &syn::Type) -> (bool, Option<&syn::Type>) {
294 if let syn::Type::Path(type_path) = ty
295 && let Some(last_segment) = type_path.path.segments.last()
296 && last_segment.ident == "Vec"
297 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
298 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
299 {
300 return (true, Some(inner_type));
301 }
302 (false, None)
303}
304
305fn is_primitive_type(ty: &syn::Type) -> bool {
307 if let syn::Type::Path(type_path) = ty
308 && let Some(last_segment) = type_path.path.segments.last()
309 {
310 let type_name = last_segment.ident.to_string();
311 matches!(
312 type_name.as_str(),
313 "String"
314 | "str"
315 | "i8"
316 | "i16"
317 | "i32"
318 | "i64"
319 | "i128"
320 | "isize"
321 | "u8"
322 | "u16"
323 | "u32"
324 | "u64"
325 | "u128"
326 | "usize"
327 | "f32"
328 | "f64"
329 | "bool"
330 | "Vec"
331 | "Option"
332 | "HashMap"
333 | "BTreeMap"
334 | "HashSet"
335 | "BTreeSet"
336 )
337 } else {
338 true
340 }
341}
342
343fn format_type_for_schema(ty: &syn::Type) -> String {
345 match ty {
347 syn::Type::Path(type_path) => {
348 let path = &type_path.path;
349 if let Some(last_segment) = path.segments.last() {
350 let type_name = last_segment.ident.to_string();
351
352 if type_name == "Option"
354 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
355 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
356 {
357 return format!("{} | null", format_type_for_schema(inner_type));
358 }
359
360 match type_name.as_str() {
362 "String" | "str" => "string".to_string(),
363 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
364 | "u64" | "u128" | "usize" => "number".to_string(),
365 "f32" | "f64" => "number".to_string(),
366 "bool" => "boolean".to_string(),
367 "Vec" => {
368 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
369 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
370 {
371 return format!("{}[]", format_type_for_schema(inner_type));
372 }
373 "array".to_string()
374 }
375 _ => type_name.to_lowercase(),
376 }
377 } else {
378 "unknown".to_string()
379 }
380 }
381 _ => "unknown".to_string(),
382 }
383}
384
/// Variant-level `#[prompt(...)]` setting, as returned by `parse_prompt_attribute`.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the enum derive leaves the variant out of `prompt_schema()`.
    Skip,
    /// `#[prompt("...")]`: an explicit description for the variant.
    Description(String),
    /// No recognised `#[prompt(...)]` attribute was found.
    None,
}
391
392fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
394 for attr in attrs {
395 if attr.path().is_ident("prompt") {
396 if let Ok(meta_list) = attr.meta.require_list() {
398 let tokens = &meta_list.tokens;
399 let tokens_str = tokens.to_string();
400 if tokens_str == "skip" {
401 return PromptAttribute::Skip;
402 }
403 }
404
405 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
407 return PromptAttribute::Description(lit_str.value());
408 }
409 }
410 }
411 PromptAttribute::None
412}
413
/// Field-level options parsed from `#[prompt(...)]` by `parse_field_prompt_attrs`.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of the default, schema and example output.
    skip: bool,
    /// `rename = "..."`: label used in the default output instead of the doc comment
    /// or field name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function called with `&self.field` to render
    /// the value in the default output.
    format_with: Option<String>,
    /// `image`: the field's `to_prompt_parts()` are appended as separate prompt parts.
    image: bool,
    /// `example = "..."`: string value used for this field in example output.
    example: Option<String>,
}
423
424fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
426 let mut result = FieldPromptAttrs::default();
427
428 for attr in attrs {
429 if attr.path().is_ident("prompt") {
430 if let Ok(meta_list) = attr.meta.require_list() {
432 if let Ok(metas) =
434 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
435 {
436 for meta in metas {
437 match meta {
438 Meta::Path(path) if path.is_ident("skip") => {
439 result.skip = true;
440 }
441 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
442 if let syn::Expr::Lit(syn::ExprLit {
443 lit: syn::Lit::Str(lit_str),
444 ..
445 }) = nv.value
446 {
447 result.rename = Some(lit_str.value());
448 }
449 }
450 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
451 if let syn::Expr::Lit(syn::ExprLit {
452 lit: syn::Lit::Str(lit_str),
453 ..
454 }) = nv.value
455 {
456 result.format_with = Some(lit_str.value());
457 }
458 }
459 Meta::Path(path) if path.is_ident("image") => {
460 result.image = true;
461 }
462 Meta::NameValue(nv) if nv.path.is_ident("example") => {
463 if let syn::Expr::Lit(syn::ExprLit {
464 lit: syn::Lit::Str(lit_str),
465 ..
466 }) = nv.value
467 {
468 result.example = Some(lit_str.value());
469 }
470 }
471 _ => {}
472 }
473 }
474 } else if meta_list.tokens.to_string() == "skip" {
475 result.skip = true;
477 } else if meta_list.tokens.to_string() == "image" {
478 result.image = true;
480 }
481 }
482 }
483 }
484
485 result
486}
487
488#[proc_macro_derive(ToPrompt, attributes(prompt))]
531pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
532 let input = parse_macro_input!(input as DeriveInput);
533
534 let found_crate =
535 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
536 let crate_path = match found_crate {
537 FoundCrate::Itself => {
538 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
540 quote!(::#ident)
541 }
542 FoundCrate::Name(name) => {
543 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
544 quote!(::#ident)
545 }
546 };
547
548 match &input.data {
550 Data::Enum(data_enum) => {
551 let enum_name = &input.ident;
553 let enum_docs = extract_doc_comments(&input.attrs);
554
555 let mut variant_lines = Vec::new();
568 let mut first_variant_name = None;
569
570 for variant in &data_enum.variants {
571 let variant_name = &variant.ident;
572 let variant_name_str = variant_name.to_string();
573
574 match parse_prompt_attribute(&variant.attrs) {
575 PromptAttribute::Skip => continue,
576 PromptAttribute::Description(desc) => {
577 variant_lines.push(format!(" | \"{}\" // {}", variant_name_str, desc));
578 if first_variant_name.is_none() {
579 first_variant_name = Some(variant_name_str);
580 }
581 }
582 PromptAttribute::None => {
583 let docs = extract_doc_comments(&variant.attrs);
584 if !docs.is_empty() {
585 variant_lines
586 .push(format!(" | \"{}\" // {}", variant_name_str, docs));
587 } else {
588 variant_lines.push(format!(" | \"{}\"", variant_name_str));
589 }
590 if first_variant_name.is_none() {
591 first_variant_name = Some(variant_name_str);
592 }
593 }
594 }
595 }
596
597 let mut lines = Vec::new();
599
600 if !enum_docs.is_empty() {
602 lines.push("/**".to_string());
603 lines.push(format!(" * {}", enum_docs));
604 lines.push(" */".to_string());
605 }
606
607 lines.push(format!("type {} =", enum_name));
609
610 for line in &variant_lines {
612 lines.push(line.clone());
613 }
614
615 if let Some(last) = lines.last_mut()
617 && !last.ends_with(';')
618 {
619 last.push(';');
620 }
621
622 if let Some(first_name) = first_variant_name {
624 lines.push("".to_string()); lines.push(format!("Example value: \"{}\"", first_name));
626 }
627
628 let prompt_string = lines.join("\n");
629 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
630
631 let mut match_arms = Vec::new();
633 for variant in &data_enum.variants {
634 let variant_name = &variant.ident;
635
636 match parse_prompt_attribute(&variant.attrs) {
638 PromptAttribute::Skip => {
639 match_arms.push(quote! {
641 Self::#variant_name => stringify!(#variant_name).to_string()
642 });
643 }
644 PromptAttribute::Description(desc) => {
645 match_arms.push(quote! {
647 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #desc)
648 });
649 }
650 PromptAttribute::None => {
651 let variant_docs = extract_doc_comments(&variant.attrs);
653 if !variant_docs.is_empty() {
654 match_arms.push(quote! {
655 Self::#variant_name => format!("{}: {}", stringify!(#variant_name), #variant_docs)
656 });
657 } else {
658 match_arms.push(quote! {
659 Self::#variant_name => stringify!(#variant_name).to_string()
660 });
661 }
662 }
663 }
664 }
665
666 let to_prompt_impl = if match_arms.is_empty() {
667 quote! {
669 fn to_prompt(&self) -> String {
670 match *self {}
671 }
672 }
673 } else {
674 quote! {
675 fn to_prompt(&self) -> String {
676 match self {
677 #(#match_arms),*
678 }
679 }
680 }
681 };
682
683 let expanded = quote! {
684 impl #impl_generics #crate_path::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
685 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
686 vec![#crate_path::prompt::PromptPart::Text(self.to_prompt())]
687 }
688
689 #to_prompt_impl
690
691 fn prompt_schema() -> String {
692 #prompt_string.to_string()
693 }
694 }
695 };
696
697 TokenStream::from(expanded)
698 }
699 Data::Struct(data_struct) => {
700 let mut template_attr = None;
702 let mut template_file_attr = None;
703 let mut mode_attr = None;
704 let mut validate_attr = false;
705 let mut type_marker_attr = false;
706
707 for attr in &input.attrs {
708 if attr.path().is_ident("prompt") {
709 if let Ok(metas) =
711 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
712 {
713 for meta in metas {
714 match meta {
715 Meta::NameValue(nv) if nv.path.is_ident("template") => {
716 if let syn::Expr::Lit(expr_lit) = nv.value
717 && let syn::Lit::Str(lit_str) = expr_lit.lit
718 {
719 template_attr = Some(lit_str.value());
720 }
721 }
722 Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
723 if let syn::Expr::Lit(expr_lit) = nv.value
724 && let syn::Lit::Str(lit_str) = expr_lit.lit
725 {
726 template_file_attr = Some(lit_str.value());
727 }
728 }
729 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
730 if let syn::Expr::Lit(expr_lit) = nv.value
731 && let syn::Lit::Str(lit_str) = expr_lit.lit
732 {
733 mode_attr = Some(lit_str.value());
734 }
735 }
736 Meta::NameValue(nv) if nv.path.is_ident("validate") => {
737 if let syn::Expr::Lit(expr_lit) = nv.value
738 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
739 {
740 validate_attr = lit_bool.value();
741 }
742 }
743 Meta::NameValue(nv) if nv.path.is_ident("type_marker") => {
744 if let syn::Expr::Lit(expr_lit) = nv.value
745 && let syn::Lit::Bool(lit_bool) = expr_lit.lit
746 {
747 type_marker_attr = lit_bool.value();
748 }
749 }
750 Meta::Path(path) if path.is_ident("type_marker") => {
751 type_marker_attr = true;
753 }
754 _ => {}
755 }
756 }
757 }
758 }
759 }
760
761 if template_attr.is_some() && template_file_attr.is_some() {
763 return syn::Error::new(
764 input.ident.span(),
765 "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
766 ).to_compile_error().into();
767 }
768
769 let template_str = if let Some(file_path) = template_file_attr {
771 let mut full_path = None;
775
776 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
778 let is_trybuild = manifest_dir.contains("target/tests/trybuild");
780
781 if !is_trybuild {
782 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
784 if candidate.exists() {
785 full_path = Some(candidate);
786 }
787 } else {
788 if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
794 let workspace_root = &manifest_dir[..target_pos];
795 let original_macros_dir = std::path::Path::new(workspace_root)
797 .join("crates")
798 .join("llm-toolkit-macros");
799
800 let candidate = original_macros_dir.join(&file_path);
801 if candidate.exists() {
802 full_path = Some(candidate);
803 }
804 }
805 }
806 }
807
808 if full_path.is_none() {
810 let candidate = std::path::Path::new(&file_path).to_path_buf();
811 if candidate.exists() {
812 full_path = Some(candidate);
813 }
814 }
815
816 if full_path.is_none()
819 && let Ok(current_dir) = std::env::current_dir()
820 {
821 let mut search_dir = current_dir.as_path();
822 for _ in 0..10 {
824 let macros_dir = search_dir.join("crates/llm-toolkit-macros");
826 if macros_dir.exists() {
827 let candidate = macros_dir.join(&file_path);
828 if candidate.exists() {
829 full_path = Some(candidate);
830 break;
831 }
832 }
833 let candidate = search_dir.join(&file_path);
835 if candidate.exists() {
836 full_path = Some(candidate);
837 break;
838 }
839 if let Some(parent) = search_dir.parent() {
840 search_dir = parent;
841 } else {
842 break;
843 }
844 }
845 }
846
847 if full_path.is_none() {
849 let mut error_msg = format!(
851 "Template file '{}' not found at compile time.\n\nSearched in:",
852 file_path
853 );
854
855 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
856 let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
857 error_msg.push_str(&format!("\n - {}", candidate.display()));
858 }
859
860 if let Ok(current_dir) = std::env::current_dir() {
861 let candidate = current_dir.join(&file_path);
862 error_msg.push_str(&format!("\n - {}", candidate.display()));
863 }
864
865 error_msg.push_str("\n\nPlease ensure:");
866 error_msg.push_str("\n 1. The template file exists");
867 error_msg.push_str("\n 2. The path is relative to CARGO_MANIFEST_DIR");
868 error_msg.push_str("\n 3. There are no typos in the path");
869
870 return syn::Error::new(input.ident.span(), error_msg)
871 .to_compile_error()
872 .into();
873 }
874
875 let final_path = full_path.unwrap();
876
877 match std::fs::read_to_string(&final_path) {
879 Ok(content) => Some(content),
880 Err(e) => {
881 return syn::Error::new(
882 input.ident.span(),
883 format!(
884 "Failed to read template file '{}': {}\n\nPath resolved to: {}",
885 file_path,
886 e,
887 final_path.display()
888 ),
889 )
890 .to_compile_error()
891 .into();
892 }
893 }
894 } else {
895 template_attr
896 };
897
898 if validate_attr && let Some(template) = &template_str {
900 let mut env = minijinja::Environment::new();
902 if let Err(e) = env.add_template("validation", template) {
903 let warning_msg =
905 format!("Template validation warning: Invalid Jinja syntax - {}", e);
906 let warning_ident = syn::Ident::new(
907 "TEMPLATE_VALIDATION_WARNING",
908 proc_macro2::Span::call_site(),
909 );
910 let _warning_tokens = quote! {
911 #[deprecated(note = #warning_msg)]
912 const #warning_ident: () = ();
913 let _ = #warning_ident;
914 };
915 eprintln!("cargo:warning={}", warning_msg);
917 }
918
919 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
921 &fields.named
922 } else {
923 panic!("Template validation is only supported for structs with named fields.");
924 };
925
926 let field_names: std::collections::HashSet<String> = fields
927 .iter()
928 .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
929 .collect();
930
931 let placeholders = parse_template_placeholders_with_mode(template);
933
934 for (placeholder_name, _mode) in &placeholders {
935 if placeholder_name != "self" && !field_names.contains(placeholder_name) {
936 let warning_msg = format!(
937 "Template validation warning: Variable '{}' used in template but not found in struct fields",
938 placeholder_name
939 );
940 eprintln!("cargo:warning={}", warning_msg);
941 }
942 }
943 }
944
945 let name = input.ident;
946 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
947
948 let struct_docs = extract_doc_comments(&input.attrs);
950
951 let is_mode_based =
953 mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
954
955 let expanded = if is_mode_based || mode_attr.is_some() {
956 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
958 &fields.named
959 } else {
960 panic!(
961 "Mode-based prompt generation is only supported for structs with named fields."
962 );
963 };
964
965 let struct_name_str = name.to_string();
966
967 let has_default = input.attrs.iter().any(|attr| {
969 if attr.path().is_ident("derive")
970 && let Ok(meta_list) = attr.meta.require_list()
971 {
972 let tokens_str = meta_list.tokens.to_string();
973 tokens_str.contains("Default")
974 } else {
975 false
976 }
977 });
978
979 let schema_parts = generate_schema_only_parts(
990 &struct_name_str,
991 &struct_docs,
992 fields,
993 &crate_path,
994 type_marker_attr,
995 );
996
997 let example_parts = generate_example_only_parts(fields, has_default, &crate_path);
999
1000 quote! {
1001 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1002 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<#crate_path::prompt::PromptPart> {
1003 match mode {
1004 "schema_only" => #schema_parts,
1005 "example_only" => #example_parts,
1006 "full" | _ => {
1007 let mut parts = Vec::new();
1009
1010 let schema_parts = #schema_parts;
1012 parts.extend(schema_parts);
1013
1014 parts.push(#crate_path::prompt::PromptPart::Text("\n### Example".to_string()));
1016 parts.push(#crate_path::prompt::PromptPart::Text(
1017 format!("Here is an example of a valid `{}` object:", #struct_name_str)
1018 ));
1019
1020 let example_parts = #example_parts;
1022 parts.extend(example_parts);
1023
1024 parts
1025 }
1026 }
1027 }
1028
1029 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1030 self.to_prompt_parts_with_mode("full")
1031 }
1032
1033 fn to_prompt(&self) -> String {
1034 self.to_prompt_parts()
1035 .into_iter()
1036 .filter_map(|part| match part {
1037 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1038 _ => None,
1039 })
1040 .collect::<Vec<_>>()
1041 .join("\n")
1042 }
1043
1044 fn prompt_schema() -> String {
1045 use std::sync::OnceLock;
1046 static SCHEMA_CACHE: OnceLock<String> = OnceLock::new();
1047
1048 SCHEMA_CACHE.get_or_init(|| {
1049 let schema_parts = #schema_parts;
1050 schema_parts
1051 .into_iter()
1052 .filter_map(|part| match part {
1053 #crate_path::prompt::PromptPart::Text(text) => Some(text),
1054 _ => None,
1055 })
1056 .collect::<Vec<_>>()
1057 .join("\n")
1058 }).clone()
1059 }
1060 }
1061 }
1062 } else if let Some(template) = template_str {
1063 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1066 &fields.named
1067 } else {
1068 panic!(
1069 "Template prompt generation is only supported for structs with named fields."
1070 );
1071 };
1072
1073 let placeholders = parse_template_placeholders_with_mode(&template);
1075 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
1077 mode.is_some()
1078 && fields
1079 .iter()
1080 .any(|f| f.ident.as_ref().unwrap() == field_name)
1081 });
1082
1083 let mut image_field_parts = Vec::new();
1084 for f in fields.iter() {
1085 let field_name = f.ident.as_ref().unwrap();
1086 let attrs = parse_field_prompt_attrs(&f.attrs);
1087
1088 if attrs.image {
1089 image_field_parts.push(quote! {
1091 parts.extend(self.#field_name.to_prompt_parts());
1092 });
1093 }
1094 }
1095
1096 if has_mode_syntax {
1098 let mut context_fields = Vec::new();
1100 let mut modified_template = template.clone();
1101
1102 for (field_name, mode_opt) in &placeholders {
1104 if let Some(mode) = mode_opt {
1105 let unique_key = format!("{}__{}", field_name, mode);
1107
1108 let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
1110 let replacement = format!("{{{{ {} }}}}", unique_key);
1111 modified_template = modified_template.replace(&pattern, &replacement);
1112
1113 let field_ident =
1115 syn::Ident::new(field_name, proc_macro2::Span::call_site());
1116
1117 context_fields.push(quote! {
1119 context.insert(
1120 #unique_key.to_string(),
1121 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
1122 );
1123 });
1124 }
1125 }
1126
1127 for field in fields.iter() {
1129 let field_name = field.ident.as_ref().unwrap();
1130 let field_name_str = field_name.to_string();
1131
1132 let has_mode_entry = placeholders
1134 .iter()
1135 .any(|(name, mode)| name == &field_name_str && mode.is_some());
1136
1137 if !has_mode_entry {
1138 let is_primitive = match &field.ty {
1141 syn::Type::Path(type_path) => {
1142 if let Some(segment) = type_path.path.segments.last() {
1143 let type_name = segment.ident.to_string();
1144 matches!(
1145 type_name.as_str(),
1146 "String"
1147 | "str"
1148 | "i8"
1149 | "i16"
1150 | "i32"
1151 | "i64"
1152 | "i128"
1153 | "isize"
1154 | "u8"
1155 | "u16"
1156 | "u32"
1157 | "u64"
1158 | "u128"
1159 | "usize"
1160 | "f32"
1161 | "f64"
1162 | "bool"
1163 | "char"
1164 )
1165 } else {
1166 false
1167 }
1168 }
1169 _ => false,
1170 };
1171
1172 if is_primitive {
1173 context_fields.push(quote! {
1174 context.insert(
1175 #field_name_str.to_string(),
1176 minijinja::Value::from_serialize(&self.#field_name)
1177 );
1178 });
1179 } else {
1180 context_fields.push(quote! {
1182 context.insert(
1183 #field_name_str.to_string(),
1184 minijinja::Value::from(self.#field_name.to_prompt())
1185 );
1186 });
1187 }
1188 }
1189 }
1190
1191 quote! {
1192 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1193 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1194 let mut parts = Vec::new();
1195
1196 #(#image_field_parts)*
1198
1199 let text = {
1201 let mut env = minijinja::Environment::new();
1202 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1203 panic!("Failed to parse template: {}", e)
1204 });
1205
1206 let tmpl = env.get_template("prompt").unwrap();
1207
1208 let mut context = std::collections::HashMap::new();
1209 #(#context_fields)*
1210
1211 tmpl.render(context).unwrap_or_else(|e| {
1212 format!("Failed to render prompt: {}", e)
1213 })
1214 };
1215
1216 if !text.is_empty() {
1217 parts.push(#crate_path::prompt::PromptPart::Text(text));
1218 }
1219
1220 parts
1221 }
1222
1223 fn to_prompt(&self) -> String {
1224 let mut env = minijinja::Environment::new();
1226 env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
1227 panic!("Failed to parse template: {}", e)
1228 });
1229
1230 let tmpl = env.get_template("prompt").unwrap();
1231
1232 let mut context = std::collections::HashMap::new();
1233 #(#context_fields)*
1234
1235 tmpl.render(context).unwrap_or_else(|e| {
1236 format!("Failed to render prompt: {}", e)
1237 })
1238 }
1239
1240 fn prompt_schema() -> String {
1241 String::new() }
1243 }
1244 }
1245 } else {
1246 quote! {
1248 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1249 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1250 let mut parts = Vec::new();
1251
1252 #(#image_field_parts)*
1254
1255 let text = #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1257 format!("Failed to render prompt: {}", e)
1258 });
1259 if !text.is_empty() {
1260 parts.push(#crate_path::prompt::PromptPart::Text(text));
1261 }
1262
1263 parts
1264 }
1265
1266 fn to_prompt(&self) -> String {
1267 #crate_path::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
1268 format!("Failed to render prompt: {}", e)
1269 })
1270 }
1271
1272 fn prompt_schema() -> String {
1273 String::new() }
1275 }
1276 }
1277 }
1278 } else {
1279 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
1282 &fields.named
1283 } else {
1284 panic!(
1285 "Default prompt generation is only supported for structs with named fields."
1286 );
1287 };
1288
1289 let mut text_field_parts = Vec::new();
1291 let mut image_field_parts = Vec::new();
1292
1293 for f in fields.iter() {
1294 let field_name = f.ident.as_ref().unwrap();
1295 let attrs = parse_field_prompt_attrs(&f.attrs);
1296
1297 if attrs.skip {
1299 continue;
1300 }
1301
1302 if attrs.image {
1303 image_field_parts.push(quote! {
1305 parts.extend(self.#field_name.to_prompt_parts());
1306 });
1307 } else {
1308 let key = if let Some(rename) = attrs.rename {
1314 rename
1315 } else {
1316 let doc_comment = extract_doc_comments(&f.attrs);
1317 if !doc_comment.is_empty() {
1318 doc_comment
1319 } else {
1320 field_name.to_string()
1321 }
1322 };
1323
1324 let value_expr = if let Some(format_with) = attrs.format_with {
1326 let func_path: syn::Path =
1328 syn::parse_str(&format_with).unwrap_or_else(|_| {
1329 panic!("Invalid function path: {}", format_with)
1330 });
1331 quote! { #func_path(&self.#field_name) }
1332 } else {
1333 quote! { self.#field_name.to_prompt() }
1334 };
1335
1336 text_field_parts.push(quote! {
1337 text_parts.push(format!("{}: {}", #key, #value_expr));
1338 });
1339 }
1340 }
1341
1342 quote! {
1344 impl #impl_generics #crate_path::prompt::ToPrompt for #name #ty_generics #where_clause {
1345 fn to_prompt_parts(&self) -> Vec<#crate_path::prompt::PromptPart> {
1346 let mut parts = Vec::new();
1347
1348 #(#image_field_parts)*
1350
1351 let mut text_parts = Vec::new();
1353 #(#text_field_parts)*
1354
1355 if !text_parts.is_empty() {
1356 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1357 }
1358
1359 parts
1360 }
1361
1362 fn to_prompt(&self) -> String {
1363 let mut text_parts = Vec::new();
1364 #(#text_field_parts)*
1365 text_parts.join("\n")
1366 }
1367
1368 fn prompt_schema() -> String {
1369 String::new() }
1371 }
1372 }
1373 };
1374
1375 TokenStream::from(expanded)
1376 }
1377 Data::Union(_) => {
1378 panic!("`#[derive(ToPrompt)]` is not supported for unions");
1379 }
1380 }
1381}
1382
/// Struct-level `#[prompt_for(...)]` target description.
///
/// NOTE(review): it is populated by `parse_struct_prompt_for_attrs`, whose body is only
/// partly visible here — confirm the field semantics against that function.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name, from `name = "..."`.
    name: String,
    /// Optional per-target template, presumably from `template = "..."` — TODO confirm.
    template: Option<String>,
    /// Per-field configuration for this target, keyed by field name (assumed — verify
    /// against the code that fills it).
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}
1390
/// Field-level options parsed from one `#[prompt_for(...)]` attribute by
/// `parse_prompt_for_attrs`.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip` flag (also set by the bare `#[prompt_for(skip)]` form).
    skip: bool,
    /// `rename = "..."`.
    rename: Option<String>,
    /// `format_with = "..."` function path.
    format_with: Option<String>,
    /// `image` flag.
    image: bool,
    /// Set when the attribute named a specific target via `name = "..."`.
    include_only: bool, }
1400
1401fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1403 let mut configs = Vec::new();
1404
1405 for attr in attrs {
1406 if attr.path().is_ident("prompt_for")
1407 && let Ok(meta_list) = attr.meta.require_list()
1408 {
1409 if meta_list.tokens.to_string() == "skip" {
1411 let config = FieldTargetConfig {
1413 skip: true,
1414 ..Default::default()
1415 };
1416 configs.push(("*".to_string(), config));
1417 } else if let Ok(metas) =
1418 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1419 {
1420 let mut target_name = None;
1421 let mut config = FieldTargetConfig::default();
1422
1423 for meta in metas {
1424 match meta {
1425 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1426 if let syn::Expr::Lit(syn::ExprLit {
1427 lit: syn::Lit::Str(lit_str),
1428 ..
1429 }) = nv.value
1430 {
1431 target_name = Some(lit_str.value());
1432 }
1433 }
1434 Meta::Path(path) if path.is_ident("skip") => {
1435 config.skip = true;
1436 }
1437 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1438 if let syn::Expr::Lit(syn::ExprLit {
1439 lit: syn::Lit::Str(lit_str),
1440 ..
1441 }) = nv.value
1442 {
1443 config.rename = Some(lit_str.value());
1444 }
1445 }
1446 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1447 if let syn::Expr::Lit(syn::ExprLit {
1448 lit: syn::Lit::Str(lit_str),
1449 ..
1450 }) = nv.value
1451 {
1452 config.format_with = Some(lit_str.value());
1453 }
1454 }
1455 Meta::Path(path) if path.is_ident("image") => {
1456 config.image = true;
1457 }
1458 _ => {}
1459 }
1460 }
1461
1462 if let Some(name) = target_name {
1463 config.include_only = true;
1464 configs.push((name, config));
1465 }
1466 }
1467 }
1468 }
1469
1470 configs
1471}
1472
1473fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1475 let mut targets = Vec::new();
1476
1477 for attr in attrs {
1478 if attr.path().is_ident("prompt_for")
1479 && let Ok(meta_list) = attr.meta.require_list()
1480 && let Ok(metas) =
1481 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1482 {
1483 let mut target_name = None;
1484 let mut template = None;
1485
1486 for meta in metas {
1487 match meta {
1488 Meta::NameValue(nv) if nv.path.is_ident("name") => {
1489 if let syn::Expr::Lit(syn::ExprLit {
1490 lit: syn::Lit::Str(lit_str),
1491 ..
1492 }) = nv.value
1493 {
1494 target_name = Some(lit_str.value());
1495 }
1496 }
1497 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1498 if let syn::Expr::Lit(syn::ExprLit {
1499 lit: syn::Lit::Str(lit_str),
1500 ..
1501 }) = nv.value
1502 {
1503 template = Some(lit_str.value());
1504 }
1505 }
1506 _ => {}
1507 }
1508 }
1509
1510 if let Some(name) = target_name {
1511 targets.push(TargetInfo {
1512 name,
1513 template,
1514 field_configs: std::collections::HashMap::new(),
1515 });
1516 }
1517 }
1518 }
1519
1520 targets
1521}
1522
1523#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1524pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1525 let input = parse_macro_input!(input as DeriveInput);
1526
1527 let found_crate =
1528 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1529 let crate_path = match found_crate {
1530 FoundCrate::Itself => {
1531 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1533 quote!(::#ident)
1534 }
1535 FoundCrate::Name(name) => {
1536 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1537 quote!(::#ident)
1538 }
1539 };
1540
1541 let data_struct = match &input.data {
1543 Data::Struct(data) => data,
1544 _ => {
1545 return syn::Error::new(
1546 input.ident.span(),
1547 "`#[derive(ToPromptSet)]` is only supported for structs",
1548 )
1549 .to_compile_error()
1550 .into();
1551 }
1552 };
1553
1554 let fields = match &data_struct.fields {
1555 syn::Fields::Named(fields) => &fields.named,
1556 _ => {
1557 return syn::Error::new(
1558 input.ident.span(),
1559 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1560 )
1561 .to_compile_error()
1562 .into();
1563 }
1564 };
1565
1566 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1568
1569 for field in fields.iter() {
1571 let field_name = field.ident.as_ref().unwrap().to_string();
1572 let field_configs = parse_prompt_for_attrs(&field.attrs);
1573
1574 for (target_name, config) in field_configs {
1575 if target_name == "*" {
1576 for target in &mut targets {
1578 target
1579 .field_configs
1580 .entry(field_name.clone())
1581 .or_insert_with(FieldTargetConfig::default)
1582 .skip = config.skip;
1583 }
1584 } else {
1585 let target_exists = targets.iter().any(|t| t.name == target_name);
1587 if !target_exists {
1588 targets.push(TargetInfo {
1590 name: target_name.clone(),
1591 template: None,
1592 field_configs: std::collections::HashMap::new(),
1593 });
1594 }
1595
1596 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1597
1598 target.field_configs.insert(field_name.clone(), config);
1599 }
1600 }
1601 }
1602
1603 let mut match_arms = Vec::new();
1605
1606 for target in &targets {
1607 let target_name = &target.name;
1608
1609 if let Some(template_str) = &target.template {
1610 let mut image_parts = Vec::new();
1612
1613 for field in fields.iter() {
1614 let field_name = field.ident.as_ref().unwrap();
1615 let field_name_str = field_name.to_string();
1616
1617 if let Some(config) = target.field_configs.get(&field_name_str)
1618 && config.image
1619 {
1620 image_parts.push(quote! {
1621 parts.extend(self.#field_name.to_prompt_parts());
1622 });
1623 }
1624 }
1625
1626 match_arms.push(quote! {
1627 #target_name => {
1628 let mut parts = Vec::new();
1629
1630 #(#image_parts)*
1631
1632 let text = #crate_path::prompt::render_prompt(#template_str, self)
1633 .map_err(|e| #crate_path::prompt::PromptSetError::RenderFailed {
1634 target: #target_name.to_string(),
1635 source: e,
1636 })?;
1637
1638 if !text.is_empty() {
1639 parts.push(#crate_path::prompt::PromptPart::Text(text));
1640 }
1641
1642 Ok(parts)
1643 }
1644 });
1645 } else {
1646 let mut text_field_parts = Vec::new();
1648 let mut image_field_parts = Vec::new();
1649
1650 for field in fields.iter() {
1651 let field_name = field.ident.as_ref().unwrap();
1652 let field_name_str = field_name.to_string();
1653
1654 let config = target.field_configs.get(&field_name_str);
1656
1657 if let Some(cfg) = config
1659 && cfg.skip
1660 {
1661 continue;
1662 }
1663
1664 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1668 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1669 .iter()
1670 .any(|(name, _)| name != "*");
1671
1672 if has_any_target_specific_config && !is_explicitly_for_this_target {
1673 continue;
1674 }
1675
1676 if let Some(cfg) = config {
1677 if cfg.image {
1678 image_field_parts.push(quote! {
1679 parts.extend(self.#field_name.to_prompt_parts());
1680 });
1681 } else {
1682 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1683
1684 let value_expr = if let Some(format_with) = &cfg.format_with {
1685 match syn::parse_str::<syn::Path>(format_with) {
1687 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1688 Err(_) => {
1689 let error_msg = format!(
1691 "Invalid function path in format_with: '{}'",
1692 format_with
1693 );
1694 quote! {
1695 compile_error!(#error_msg);
1696 String::new()
1697 }
1698 }
1699 }
1700 } else {
1701 quote! { self.#field_name.to_prompt() }
1702 };
1703
1704 text_field_parts.push(quote! {
1705 text_parts.push(format!("{}: {}", #key, #value_expr));
1706 });
1707 }
1708 } else {
1709 text_field_parts.push(quote! {
1711 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1712 });
1713 }
1714 }
1715
1716 match_arms.push(quote! {
1717 #target_name => {
1718 let mut parts = Vec::new();
1719
1720 #(#image_field_parts)*
1721
1722 let mut text_parts = Vec::new();
1723 #(#text_field_parts)*
1724
1725 if !text_parts.is_empty() {
1726 parts.push(#crate_path::prompt::PromptPart::Text(text_parts.join("\n")));
1727 }
1728
1729 Ok(parts)
1730 }
1731 });
1732 }
1733 }
1734
1735 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1737
1738 match_arms.push(quote! {
1740 _ => {
1741 let available = vec![#(#target_names.to_string()),*];
1742 Err(#crate_path::prompt::PromptSetError::TargetNotFound {
1743 target: target.to_string(),
1744 available,
1745 })
1746 }
1747 });
1748
1749 let struct_name = &input.ident;
1750 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1751
1752 let expanded = quote! {
1753 impl #impl_generics #crate_path::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1754 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<#crate_path::prompt::PromptPart>, #crate_path::prompt::PromptSetError> {
1755 match target {
1756 #(#match_arms)*
1757 }
1758 }
1759 }
1760 };
1761
1762 TokenStream::from(expanded)
1763}
1764
1765struct TypeList {
1767 types: Punctuated<syn::Type, Token![,]>,
1768}
1769
1770impl Parse for TypeList {
1771 fn parse(input: ParseStream) -> syn::Result<Self> {
1772 Ok(TypeList {
1773 types: Punctuated::parse_terminated(input)?,
1774 })
1775 }
1776}
1777
1778#[proc_macro]
1802pub fn examples_section(input: TokenStream) -> TokenStream {
1803 let input = parse_macro_input!(input as TypeList);
1804
1805 let found_crate =
1806 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1807 let _crate_path = match found_crate {
1808 FoundCrate::Itself => quote!(crate),
1809 FoundCrate::Name(name) => {
1810 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1811 quote!(::#ident)
1812 }
1813 };
1814
1815 let mut type_sections = Vec::new();
1817
1818 for ty in input.types.iter() {
1819 let type_name_str = quote!(#ty).to_string();
1821
1822 type_sections.push(quote! {
1824 {
1825 let type_name = #type_name_str;
1826 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1827 format!("---\n#### `{}`\n{}", type_name, json_example)
1828 }
1829 });
1830 }
1831
1832 let expanded = quote! {
1834 {
1835 let mut sections = Vec::new();
1836 sections.push("---".to_string());
1837 sections.push("### Examples".to_string());
1838 sections.push("".to_string());
1839 sections.push("Here are examples of the data structures you should use.".to_string());
1840 sections.push("".to_string());
1841
1842 #(sections.push(#type_sections);)*
1843
1844 sections.push("---".to_string());
1845
1846 sections.join("\n")
1847 }
1848 };
1849
1850 TokenStream::from(expanded)
1851}
1852
1853fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1855 for attr in attrs {
1856 if attr.path().is_ident("prompt_for")
1857 && let Ok(meta_list) = attr.meta.require_list()
1858 && let Ok(metas) =
1859 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1860 {
1861 let mut target_type = None;
1862 let mut template = None;
1863
1864 for meta in metas {
1865 match meta {
1866 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1867 if let syn::Expr::Lit(syn::ExprLit {
1868 lit: syn::Lit::Str(lit_str),
1869 ..
1870 }) = nv.value
1871 {
1872 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1874 }
1875 }
1876 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1877 if let syn::Expr::Lit(syn::ExprLit {
1878 lit: syn::Lit::Str(lit_str),
1879 ..
1880 }) = nv.value
1881 {
1882 template = Some(lit_str.value());
1883 }
1884 }
1885 _ => {}
1886 }
1887 }
1888
1889 if let (Some(target), Some(tmpl)) = (target_type, template) {
1890 return (target, tmpl);
1891 }
1892 }
1893 }
1894
1895 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1896}
1897
1898#[proc_macro_attribute]
1932pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1933 let input = parse_macro_input!(item as DeriveInput);
1934
1935 let found_crate =
1936 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
1937 let crate_path = match found_crate {
1938 FoundCrate::Itself => {
1939 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
1941 quote!(::#ident)
1942 }
1943 FoundCrate::Name(name) => {
1944 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
1945 quote!(::#ident)
1946 }
1947 };
1948
1949 let enum_data = match &input.data {
1951 Data::Enum(data) => data,
1952 _ => {
1953 return syn::Error::new(
1954 input.ident.span(),
1955 "`#[define_intent]` can only be applied to enums",
1956 )
1957 .to_compile_error()
1958 .into();
1959 }
1960 };
1961
1962 let mut prompt_template = None;
1964 let mut extractor_tag = None;
1965 let mut mode = None;
1966
1967 for attr in &input.attrs {
1968 if attr.path().is_ident("intent")
1969 && let Ok(metas) =
1970 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1971 {
1972 for meta in metas {
1973 match meta {
1974 Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1975 if let syn::Expr::Lit(syn::ExprLit {
1976 lit: syn::Lit::Str(lit_str),
1977 ..
1978 }) = nv.value
1979 {
1980 prompt_template = Some(lit_str.value());
1981 }
1982 }
1983 Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1984 if let syn::Expr::Lit(syn::ExprLit {
1985 lit: syn::Lit::Str(lit_str),
1986 ..
1987 }) = nv.value
1988 {
1989 extractor_tag = Some(lit_str.value());
1990 }
1991 }
1992 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1993 if let syn::Expr::Lit(syn::ExprLit {
1994 lit: syn::Lit::Str(lit_str),
1995 ..
1996 }) = nv.value
1997 {
1998 mode = Some(lit_str.value());
1999 }
2000 }
2001 _ => {}
2002 }
2003 }
2004 }
2005 }
2006
2007 let mode = mode.unwrap_or_else(|| "single".to_string());
2009
2010 if mode != "single" && mode != "multi_tag" {
2012 return syn::Error::new(
2013 input.ident.span(),
2014 "`mode` must be either \"single\" or \"multi_tag\"",
2015 )
2016 .to_compile_error()
2017 .into();
2018 }
2019
2020 let prompt_template = match prompt_template {
2022 Some(p) => p,
2023 None => {
2024 return syn::Error::new(
2025 input.ident.span(),
2026 "`#[intent(...)]` attribute must include `prompt = \"...\"`",
2027 )
2028 .to_compile_error()
2029 .into();
2030 }
2031 };
2032
2033 if mode == "multi_tag" {
2035 let enum_name = &input.ident;
2036 let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
2037 return generate_multi_tag_output(
2038 &input,
2039 enum_name,
2040 enum_data,
2041 prompt_template,
2042 actions_doc,
2043 );
2044 }
2045
2046 let extractor_tag = match extractor_tag {
2048 Some(t) => t,
2049 None => {
2050 return syn::Error::new(
2051 input.ident.span(),
2052 "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
2053 )
2054 .to_compile_error()
2055 .into();
2056 }
2057 };
2058
2059 let enum_name = &input.ident;
2061 let enum_docs = extract_doc_comments(&input.attrs);
2062
2063 let mut intents_doc_lines = Vec::new();
2064
2065 if !enum_docs.is_empty() {
2067 intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
2068 } else {
2069 intents_doc_lines.push(format!("{}:", enum_name));
2070 }
2071 intents_doc_lines.push(String::new()); intents_doc_lines.push("Possible values:".to_string());
2073
2074 for variant in &enum_data.variants {
2076 let variant_name = &variant.ident;
2077 let variant_docs = extract_doc_comments(&variant.attrs);
2078
2079 if !variant_docs.is_empty() {
2080 intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
2081 } else {
2082 intents_doc_lines.push(format!("- {}", variant_name));
2083 }
2084 }
2085
2086 let intents_doc_str = intents_doc_lines.join("\n");
2087
2088 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2090 let user_variables: Vec<String> = placeholders
2091 .iter()
2092 .filter_map(|(name, _)| {
2093 if name != "intents_doc" {
2094 Some(name.clone())
2095 } else {
2096 None
2097 }
2098 })
2099 .collect();
2100
2101 let enum_name_str = enum_name.to_string();
2103 let snake_case_name = to_snake_case(&enum_name_str);
2104 let function_name = syn::Ident::new(
2105 &format!("build_{}_prompt", snake_case_name),
2106 proc_macro2::Span::call_site(),
2107 );
2108
2109 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2111 .iter()
2112 .map(|var| {
2113 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2114 quote! { #ident: &str }
2115 })
2116 .collect();
2117
2118 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2120 .iter()
2121 .map(|var| {
2122 let var_str = var.clone();
2123 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2124 quote! {
2125 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2126 }
2127 })
2128 .collect();
2129
2130 let converted_template = prompt_template.clone();
2132
2133 let extractor_name = syn::Ident::new(
2135 &format!("{}Extractor", enum_name),
2136 proc_macro2::Span::call_site(),
2137 );
2138
2139 let filtered_attrs: Vec<_> = input
2141 .attrs
2142 .iter()
2143 .filter(|attr| !attr.path().is_ident("intent"))
2144 .collect();
2145
2146 let vis = &input.vis;
2148 let generics = &input.generics;
2149 let variants = &enum_data.variants;
2150 let enum_output = quote! {
2151 #(#filtered_attrs)*
2152 #vis enum #enum_name #generics {
2153 #variants
2154 }
2155 };
2156
2157 let expanded = quote! {
2159 #enum_output
2161
2162 pub fn #function_name(#(#function_params),*) -> String {
2164 let mut env = minijinja::Environment::new();
2165 env.add_template("prompt", #converted_template)
2166 .expect("Failed to parse intent prompt template");
2167
2168 let tmpl = env.get_template("prompt").unwrap();
2169
2170 let mut __template_context = std::collections::HashMap::new();
2171
2172 __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
2174
2175 #(#context_insertions)*
2177
2178 tmpl.render(&__template_context)
2179 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2180 }
2181
2182 pub struct #extractor_name;
2184
2185 impl #extractor_name {
2186 pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
2187 }
2188
2189 impl #crate_path::intent::IntentExtractor<#enum_name> for #extractor_name {
2190 fn extract_intent(&self, response: &str) -> Result<#enum_name, #crate_path::intent::IntentExtractionError> {
2191 #crate_path::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
2193 }
2194 }
2195 };
2196
2197 TokenStream::from(expanded)
2198}
2199
/// Converts a `CamelCase` identifier to the snake-case form used in generated
/// function names.
///
/// An underscore goes before an uppercase character only when it is not the
/// first character and does not directly follow another uppercase character,
/// so runs of capitals collapse (`"HTTPServer"` becomes `"httpserver"`).
/// Uppercase characters are replaced by the first character of their lowercase
/// mapping; all other characters are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len());
    let mut last_was_upper = false;

    for (position, c) in s.chars().enumerate() {
        let is_upper = c.is_uppercase();
        if is_upper && position != 0 && !last_was_upper {
            snake.push('_');
        }
        if is_upper {
            snake.extend(c.to_lowercase().take(1));
        } else {
            snake.push(c);
        }
        last_was_upper = is_upper;
    }

    snake
}
2220
/// Variant-level settings from `#[action(...)]` on a multi-tag intent enum.
#[derive(Debug, Default)]
struct ActionAttrs {
    /// XML tag from `#[action(tag = "...")]`. Variants without one are left
    /// out of the generated actions doc, tag regex and parsing arms.
    tag: Option<String>,
}
2226
2227fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
2228 let mut result = ActionAttrs::default();
2229
2230 for attr in attrs {
2231 if attr.path().is_ident("action")
2232 && let Ok(meta_list) = attr.meta.require_list()
2233 && let Ok(metas) =
2234 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2235 {
2236 for meta in metas {
2237 if let Meta::NameValue(nv) = meta
2238 && nv.path.is_ident("tag")
2239 && let syn::Expr::Lit(syn::ExprLit {
2240 lit: syn::Lit::Str(lit_str),
2241 ..
2242 }) = nv.value
2243 {
2244 result.tag = Some(lit_str.value());
2245 }
2246 }
2247 }
2248 }
2249
2250 result
2251}
2252
/// Field-level markers for struct variants of a multi-tag intent enum.
#[derive(Debug, Default)]
struct FieldActionAttrs {
    /// Set by `#[action(attribute)]`: the field is filled from the XML
    /// attribute whose key equals the field name.
    is_attribute: bool,
    /// Set by `#[action(inner_text)]`: the field receives the element's text.
    is_inner_text: bool,
}
2259
2260fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
2261 let mut result = FieldActionAttrs::default();
2262
2263 for attr in attrs {
2264 if attr.path().is_ident("action")
2265 && let Ok(meta_list) = attr.meta.require_list()
2266 {
2267 let tokens_str = meta_list.tokens.to_string();
2268 if tokens_str == "attribute" {
2269 result.is_attribute = true;
2270 } else if tokens_str == "inner_text" {
2271 result.is_inner_text = true;
2272 }
2273 }
2274 }
2275
2276 result
2277}
2278
2279fn generate_multi_tag_actions_doc(
2281 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2282) -> String {
2283 let mut doc_lines = Vec::new();
2284
2285 for variant in variants {
2286 let action_attrs = parse_action_attrs(&variant.attrs);
2287
2288 if let Some(tag) = action_attrs.tag {
2289 let variant_docs = extract_doc_comments(&variant.attrs);
2290
2291 match &variant.fields {
2292 syn::Fields::Unit => {
2293 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2295 }
2296 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2297 doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
2299 }
2300 syn::Fields::Named(fields) => {
2301 let mut attrs_str = Vec::new();
2303 let mut has_inner_text = false;
2304
2305 for field in &fields.named {
2306 let field_name = field.ident.as_ref().unwrap();
2307 let field_attrs = parse_field_action_attrs(&field.attrs);
2308
2309 if field_attrs.is_attribute {
2310 attrs_str.push(format!("{}=\"...\"", field_name));
2311 } else if field_attrs.is_inner_text {
2312 has_inner_text = true;
2313 }
2314 }
2315
2316 let attrs_part = if !attrs_str.is_empty() {
2317 format!(" {}", attrs_str.join(" "))
2318 } else {
2319 String::new()
2320 };
2321
2322 if has_inner_text {
2323 doc_lines.push(format!(
2324 "- `<{}{}>...</{}>`: {}",
2325 tag, attrs_part, tag, variant_docs
2326 ));
2327 } else if !attrs_str.is_empty() {
2328 doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
2329 } else {
2330 doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
2331 }
2332
2333 for field in &fields.named {
2335 let field_name = field.ident.as_ref().unwrap();
2336 let field_attrs = parse_field_action_attrs(&field.attrs);
2337 let field_docs = extract_doc_comments(&field.attrs);
2338
2339 if field_attrs.is_attribute {
2340 doc_lines
2341 .push(format!(" - `{}` (attribute): {}", field_name, field_docs));
2342 } else if field_attrs.is_inner_text {
2343 doc_lines
2344 .push(format!(" - `{}` (inner_text): {}", field_name, field_docs));
2345 }
2346 }
2347 }
2348 _ => {
2349 }
2351 }
2352 }
2353 }
2354
2355 doc_lines.join("\n")
2356}
2357
2358fn generate_tags_regex(
2360 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2361) -> String {
2362 let mut tag_names = Vec::new();
2363
2364 for variant in variants {
2365 let action_attrs = parse_action_attrs(&variant.attrs);
2366 if let Some(tag) = action_attrs.tag {
2367 tag_names.push(tag);
2368 }
2369 }
2370
2371 if tag_names.is_empty() {
2372 return String::new();
2373 }
2374
2375 let tags_pattern = tag_names.join("|");
2376 format!(
2379 r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2380 tags_pattern, tags_pattern, tags_pattern
2381 )
2382}
2383
2384fn generate_multi_tag_output(
2386 input: &DeriveInput,
2387 enum_name: &syn::Ident,
2388 enum_data: &syn::DataEnum,
2389 prompt_template: String,
2390 actions_doc: String,
2391) -> TokenStream {
2392 let found_crate =
2393 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2394 let crate_path = match found_crate {
2395 FoundCrate::Itself => {
2396 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2398 quote!(::#ident)
2399 }
2400 FoundCrate::Name(name) => {
2401 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2402 quote!(::#ident)
2403 }
2404 };
2405
2406 let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2408 let user_variables: Vec<String> = placeholders
2409 .iter()
2410 .filter_map(|(name, _)| {
2411 if name != "actions_doc" {
2412 Some(name.clone())
2413 } else {
2414 None
2415 }
2416 })
2417 .collect();
2418
2419 let enum_name_str = enum_name.to_string();
2421 let snake_case_name = to_snake_case(&enum_name_str);
2422 let function_name = syn::Ident::new(
2423 &format!("build_{}_prompt", snake_case_name),
2424 proc_macro2::Span::call_site(),
2425 );
2426
2427 let function_params: Vec<proc_macro2::TokenStream> = user_variables
2429 .iter()
2430 .map(|var| {
2431 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2432 quote! { #ident: &str }
2433 })
2434 .collect();
2435
2436 let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2438 .iter()
2439 .map(|var| {
2440 let var_str = var.clone();
2441 let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2442 quote! {
2443 __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2444 }
2445 })
2446 .collect();
2447
2448 let extractor_name = syn::Ident::new(
2450 &format!("{}Extractor", enum_name),
2451 proc_macro2::Span::call_site(),
2452 );
2453
2454 let filtered_attrs: Vec<_> = input
2456 .attrs
2457 .iter()
2458 .filter(|attr| !attr.path().is_ident("intent"))
2459 .collect();
2460
2461 let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2463 .variants
2464 .iter()
2465 .map(|variant| {
2466 let variant_name = &variant.ident;
2467 let variant_attrs: Vec<_> = variant
2468 .attrs
2469 .iter()
2470 .filter(|attr| !attr.path().is_ident("action"))
2471 .collect();
2472 let fields = &variant.fields;
2473
2474 let filtered_fields = match fields {
2476 syn::Fields::Named(named_fields) => {
2477 let filtered: Vec<_> = named_fields
2478 .named
2479 .iter()
2480 .map(|field| {
2481 let field_name = &field.ident;
2482 let field_type = &field.ty;
2483 let field_vis = &field.vis;
2484 let filtered_attrs: Vec<_> = field
2485 .attrs
2486 .iter()
2487 .filter(|attr| !attr.path().is_ident("action"))
2488 .collect();
2489 quote! {
2490 #(#filtered_attrs)*
2491 #field_vis #field_name: #field_type
2492 }
2493 })
2494 .collect();
2495 quote! { { #(#filtered,)* } }
2496 }
2497 syn::Fields::Unnamed(unnamed_fields) => {
2498 let types: Vec<_> = unnamed_fields
2499 .unnamed
2500 .iter()
2501 .map(|field| {
2502 let field_type = &field.ty;
2503 quote! { #field_type }
2504 })
2505 .collect();
2506 quote! { (#(#types),*) }
2507 }
2508 syn::Fields::Unit => quote! {},
2509 };
2510
2511 quote! {
2512 #(#variant_attrs)*
2513 #variant_name #filtered_fields
2514 }
2515 })
2516 .collect();
2517
2518 let vis = &input.vis;
2519 let generics = &input.generics;
2520
2521 let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2523
2524 let tags_regex = generate_tags_regex(&enum_data.variants);
2526
2527 let expanded = quote! {
2528 #(#filtered_attrs)*
2530 #vis enum #enum_name #generics {
2531 #(#filtered_variants),*
2532 }
2533
2534 pub fn #function_name(#(#function_params),*) -> String {
2536 let mut env = minijinja::Environment::new();
2537 env.add_template("prompt", #prompt_template)
2538 .expect("Failed to parse intent prompt template");
2539
2540 let tmpl = env.get_template("prompt").unwrap();
2541
2542 let mut __template_context = std::collections::HashMap::new();
2543
2544 __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2546
2547 #(#context_insertions)*
2549
2550 tmpl.render(&__template_context)
2551 .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2552 }
2553
2554 pub struct #extractor_name;
2556
2557 impl #extractor_name {
2558 fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2559 use ::quick_xml::events::Event;
2560 use ::quick_xml::Reader;
2561
2562 let mut actions = Vec::new();
2563 let mut reader = Reader::from_str(text);
2564 reader.config_mut().trim_text(true);
2565
2566 let mut buf = Vec::new();
2567
2568 loop {
2569 match reader.read_event_into(&mut buf) {
2570 Ok(Event::Start(e)) => {
2571 let owned_e = e.into_owned();
2572 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2573 let is_empty = false;
2574
2575 #parsing_arms
2576 }
2577 Ok(Event::Empty(e)) => {
2578 let owned_e = e.into_owned();
2579 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2580 let is_empty = true;
2581
2582 #parsing_arms
2583 }
2584 Ok(Event::Eof) => break,
2585 Err(_) => {
2586 break;
2588 }
2589 _ => {}
2590 }
2591 buf.clear();
2592 }
2593
2594 actions.into_iter().next()
2595 }
2596
2597 pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, #crate_path::intent::IntentError> {
2598 use ::quick_xml::events::Event;
2599 use ::quick_xml::Reader;
2600
2601 let mut actions = Vec::new();
2602 let mut reader = Reader::from_str(text);
2603 reader.config_mut().trim_text(true);
2604
2605 let mut buf = Vec::new();
2606
2607 loop {
2608 match reader.read_event_into(&mut buf) {
2609 Ok(Event::Start(e)) => {
2610 let owned_e = e.into_owned();
2611 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2612 let is_empty = false;
2613
2614 #parsing_arms
2615 }
2616 Ok(Event::Empty(e)) => {
2617 let owned_e = e.into_owned();
2618 let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2619 let is_empty = true;
2620
2621 #parsing_arms
2622 }
2623 Ok(Event::Eof) => break,
2624 Err(_) => {
2625 break;
2627 }
2628 _ => {}
2629 }
2630 buf.clear();
2631 }
2632
2633 Ok(actions)
2634 }
2635
2636 pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2637 where
2638 F: FnMut(#enum_name) -> String,
2639 {
2640 use ::regex::Regex;
2641
2642 let regex_pattern = #tags_regex;
2643 if regex_pattern.is_empty() {
2644 return text.to_string();
2645 }
2646
2647 let re = Regex::new(®ex_pattern).unwrap_or_else(|e| {
2648 panic!("Failed to compile regex for action tags: {}", e);
2649 });
2650
2651 re.replace_all(text, |caps: &::regex::Captures| {
2652 let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2653
2654 if let Some(action) = self.parse_single_action(matched) {
2656 transformer(action)
2657 } else {
2658 matched.to_string()
2660 }
2661 }).to_string()
2662 }
2663
2664 pub fn strip_actions(&self, text: &str) -> String {
2665 self.transform_actions(text, |_| String::new())
2666 }
2667 }
2668 };
2669
2670 TokenStream::from(expanded)
2671}
2672
2673fn generate_parsing_arms(
2675 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2676 enum_name: &syn::Ident,
2677) -> proc_macro2::TokenStream {
2678 let mut arms = Vec::new();
2679
2680 for variant in variants {
2681 let variant_name = &variant.ident;
2682 let action_attrs = parse_action_attrs(&variant.attrs);
2683
2684 if let Some(tag) = action_attrs.tag {
2685 match &variant.fields {
2686 syn::Fields::Unit => {
2687 arms.push(quote! {
2689 if &tag_name == #tag {
2690 actions.push(#enum_name::#variant_name);
2691 }
2692 });
2693 }
2694 syn::Fields::Unnamed(_fields) => {
2695 arms.push(quote! {
2697 if &tag_name == #tag && !is_empty {
2698 match reader.read_text(owned_e.name()) {
2700 Ok(text) => {
2701 actions.push(#enum_name::#variant_name(text.to_string()));
2702 }
2703 Err(_) => {
2704 actions.push(#enum_name::#variant_name(String::new()));
2706 }
2707 }
2708 }
2709 });
2710 }
2711 syn::Fields::Named(fields) => {
2712 let mut field_names = Vec::new();
2714 let mut has_inner_text_field = None;
2715
2716 for field in &fields.named {
2717 let field_name = field.ident.as_ref().unwrap();
2718 let field_attrs = parse_field_action_attrs(&field.attrs);
2719
2720 if field_attrs.is_attribute {
2721 field_names.push(field_name.clone());
2722 } else if field_attrs.is_inner_text {
2723 has_inner_text_field = Some(field_name.clone());
2724 }
2725 }
2726
2727 if let Some(inner_text_field) = has_inner_text_field {
2728 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2731 quote! {
2732 let mut #field_name = String::new();
2733 for attr in owned_e.attributes() {
2734 if let Ok(attr) = attr {
2735 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2736 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2737 break;
2738 }
2739 }
2740 }
2741 }
2742 }).collect();
2743
2744 arms.push(quote! {
2745 if &tag_name == #tag {
2746 #(#attr_extractions)*
2747
2748 if is_empty {
2750 let #inner_text_field = String::new();
2751 actions.push(#enum_name::#variant_name {
2752 #(#field_names,)*
2753 #inner_text_field,
2754 });
2755 } else {
2756 match reader.read_text(owned_e.name()) {
2758 Ok(text) => {
2759 let #inner_text_field = text.to_string();
2760 actions.push(#enum_name::#variant_name {
2761 #(#field_names,)*
2762 #inner_text_field,
2763 });
2764 }
2765 Err(_) => {
2766 let #inner_text_field = String::new();
2768 actions.push(#enum_name::#variant_name {
2769 #(#field_names,)*
2770 #inner_text_field,
2771 });
2772 }
2773 }
2774 }
2775 }
2776 });
2777 } else {
2778 let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2780 quote! {
2781 let mut #field_name = String::new();
2782 for attr in owned_e.attributes() {
2783 if let Ok(attr) = attr {
2784 if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2785 #field_name = String::from_utf8_lossy(&attr.value).to_string();
2786 break;
2787 }
2788 }
2789 }
2790 }
2791 }).collect();
2792
2793 arms.push(quote! {
2794 if &tag_name == #tag {
2795 #(#attr_extractions)*
2796 actions.push(#enum_name::#variant_name {
2797 #(#field_names),*
2798 });
2799 }
2800 });
2801 }
2802 }
2803 }
2804 }
2805 }
2806
2807 quote! {
2808 #(#arms)*
2809 }
2810}
2811
2812#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2814pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2815 let input = parse_macro_input!(input as DeriveInput);
2816
2817 let found_crate =
2818 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
2819 let crate_path = match found_crate {
2820 FoundCrate::Itself => {
2821 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
2823 quote!(::#ident)
2824 }
2825 FoundCrate::Name(name) => {
2826 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
2827 quote!(::#ident)
2828 }
2829 };
2830
2831 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2833
2834 let struct_name = &input.ident;
2835 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2836
2837 let placeholders = parse_template_placeholders_with_mode(&template);
2839
2840 let mut converted_template = template.clone();
2842 let mut context_fields = Vec::new();
2843
2844 let fields = match &input.data {
2846 Data::Struct(data_struct) => match &data_struct.fields {
2847 syn::Fields::Named(fields) => &fields.named,
2848 _ => panic!("ToPromptFor is only supported for structs with named fields"),
2849 },
2850 _ => panic!("ToPromptFor is only supported for structs"),
2851 };
2852
2853 let has_mode_support = input.attrs.iter().any(|attr| {
2855 if attr.path().is_ident("prompt")
2856 && let Ok(metas) =
2857 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2858 {
2859 for meta in metas {
2860 if let Meta::NameValue(nv) = meta
2861 && nv.path.is_ident("mode")
2862 {
2863 return true;
2864 }
2865 }
2866 }
2867 false
2868 });
2869
2870 for (placeholder_name, mode_opt) in &placeholders {
2872 if placeholder_name == "self" {
2873 if let Some(specific_mode) = mode_opt {
2874 let unique_key = format!("self__{}", specific_mode);
2876
2877 let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2879 let replacement = format!("{{{{ {} }}}}", unique_key);
2880 converted_template = converted_template.replace(&pattern, &replacement);
2881
2882 context_fields.push(quote! {
2884 context.insert(
2885 #unique_key.to_string(),
2886 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2887 );
2888 });
2889 } else {
2890 if has_mode_support {
2893 context_fields.push(quote! {
2895 context.insert(
2896 "self".to_string(),
2897 minijinja::Value::from(self.to_prompt_with_mode(mode))
2898 );
2899 });
2900 } else {
2901 context_fields.push(quote! {
2903 context.insert(
2904 "self".to_string(),
2905 minijinja::Value::from(self.to_prompt())
2906 );
2907 });
2908 }
2909 }
2910 } else {
2911 let field_exists = fields.iter().any(|f| {
2914 f.ident
2915 .as_ref()
2916 .is_some_and(|ident| ident == placeholder_name)
2917 });
2918
2919 if field_exists {
2920 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2921
2922 context_fields.push(quote! {
2926 context.insert(
2927 #placeholder_name.to_string(),
2928 minijinja::Value::from_serialize(&self.#field_ident)
2929 );
2930 });
2931 }
2932 }
2934 }
2935
2936 let expanded = quote! {
2937 impl #impl_generics #crate_path::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2938 where
2939 #target_type: serde::Serialize,
2940 {
2941 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2942 let mut env = minijinja::Environment::new();
2944 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2945 panic!("Failed to parse template: {}", e)
2946 });
2947
2948 let tmpl = env.get_template("prompt").unwrap();
2949
2950 let mut context = std::collections::HashMap::new();
2952 context.insert(
2954 "self".to_string(),
2955 minijinja::Value::from_serialize(self)
2956 );
2957 context.insert(
2959 "target".to_string(),
2960 minijinja::Value::from_serialize(target)
2961 );
2962 #(#context_fields)*
2963
2964 tmpl.render(context).unwrap_or_else(|e| {
2966 format!("Failed to render prompt: {}", e)
2967 })
2968 }
2969 }
2970 };
2971
2972 TokenStream::from(expanded)
2973}
2974
2975struct AgentAttrs {
2981 expertise: Option<String>,
2982 output: Option<syn::Type>,
2983 backend: Option<String>,
2984 model: Option<String>,
2985 inner: Option<String>,
2986 default_inner: Option<String>,
2987 max_retries: Option<u32>,
2988 profile: Option<String>,
2989}
2990
2991impl Parse for AgentAttrs {
2992 fn parse(input: ParseStream) -> syn::Result<Self> {
2993 let mut expertise = None;
2994 let mut output = None;
2995 let mut backend = None;
2996 let mut model = None;
2997 let mut inner = None;
2998 let mut default_inner = None;
2999 let mut max_retries = None;
3000 let mut profile = None;
3001
3002 let pairs = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
3003
3004 for meta in pairs {
3005 match meta {
3006 Meta::NameValue(nv) if nv.path.is_ident("expertise") => {
3007 if let syn::Expr::Lit(syn::ExprLit {
3008 lit: syn::Lit::Str(lit_str),
3009 ..
3010 }) = &nv.value
3011 {
3012 expertise = Some(lit_str.value());
3013 }
3014 }
3015 Meta::NameValue(nv) if nv.path.is_ident("output") => {
3016 if let syn::Expr::Lit(syn::ExprLit {
3017 lit: syn::Lit::Str(lit_str),
3018 ..
3019 }) = &nv.value
3020 {
3021 let ty: syn::Type = syn::parse_str(&lit_str.value())?;
3022 output = Some(ty);
3023 }
3024 }
3025 Meta::NameValue(nv) if nv.path.is_ident("backend") => {
3026 if let syn::Expr::Lit(syn::ExprLit {
3027 lit: syn::Lit::Str(lit_str),
3028 ..
3029 }) = &nv.value
3030 {
3031 backend = Some(lit_str.value());
3032 }
3033 }
3034 Meta::NameValue(nv) if nv.path.is_ident("model") => {
3035 if let syn::Expr::Lit(syn::ExprLit {
3036 lit: syn::Lit::Str(lit_str),
3037 ..
3038 }) = &nv.value
3039 {
3040 model = Some(lit_str.value());
3041 }
3042 }
3043 Meta::NameValue(nv) if nv.path.is_ident("inner") => {
3044 if let syn::Expr::Lit(syn::ExprLit {
3045 lit: syn::Lit::Str(lit_str),
3046 ..
3047 }) = &nv.value
3048 {
3049 inner = Some(lit_str.value());
3050 }
3051 }
3052 Meta::NameValue(nv) if nv.path.is_ident("default_inner") => {
3053 if let syn::Expr::Lit(syn::ExprLit {
3054 lit: syn::Lit::Str(lit_str),
3055 ..
3056 }) = &nv.value
3057 {
3058 default_inner = Some(lit_str.value());
3059 }
3060 }
3061 Meta::NameValue(nv) if nv.path.is_ident("max_retries") => {
3062 if let syn::Expr::Lit(syn::ExprLit {
3063 lit: syn::Lit::Int(lit_int),
3064 ..
3065 }) = &nv.value
3066 {
3067 max_retries = Some(lit_int.base10_parse()?);
3068 }
3069 }
3070 Meta::NameValue(nv) if nv.path.is_ident("profile") => {
3071 if let syn::Expr::Lit(syn::ExprLit {
3072 lit: syn::Lit::Str(lit_str),
3073 ..
3074 }) = &nv.value
3075 {
3076 profile = Some(lit_str.value());
3077 }
3078 }
3079 _ => {}
3080 }
3081 }
3082
3083 Ok(AgentAttrs {
3084 expertise,
3085 output,
3086 backend,
3087 model,
3088 inner,
3089 default_inner,
3090 max_retries,
3091 profile,
3092 })
3093 }
3094}
3095
3096fn parse_agent_attrs(attrs: &[syn::Attribute]) -> syn::Result<AgentAttrs> {
3098 for attr in attrs {
3099 if attr.path().is_ident("agent") {
3100 return attr.parse_args::<AgentAttrs>();
3101 }
3102 }
3103
3104 Ok(AgentAttrs {
3105 expertise: None,
3106 output: None,
3107 backend: None,
3108 model: None,
3109 inner: None,
3110 default_inner: None,
3111 max_retries: None,
3112 profile: None,
3113 })
3114}
3115
3116fn generate_backend_constructors(
3118 struct_name: &syn::Ident,
3119 backend: &str,
3120 _model: Option<&str>,
3121 _profile: Option<&str>,
3122 crate_path: &proc_macro2::TokenStream,
3123) -> proc_macro2::TokenStream {
3124 match backend {
3125 "claude" => {
3126 quote! {
3127 impl #struct_name {
3128 pub fn with_claude() -> Self {
3130 Self::new(#crate_path::agent::impls::ClaudeCodeAgent::new())
3131 }
3132
3133 pub fn with_claude_model(model: &str) -> Self {
3135 Self::new(
3136 #crate_path::agent::impls::ClaudeCodeAgent::new()
3137 .with_model_str(model)
3138 )
3139 }
3140 }
3141 }
3142 }
3143 "gemini" => {
3144 quote! {
3145 impl #struct_name {
3146 pub fn with_gemini() -> Self {
3148 Self::new(#crate_path::agent::impls::GeminiAgent::new())
3149 }
3150
3151 pub fn with_gemini_model(model: &str) -> Self {
3153 Self::new(
3154 #crate_path::agent::impls::GeminiAgent::new()
3155 .with_model_str(model)
3156 )
3157 }
3158 }
3159 }
3160 }
3161 _ => quote! {},
3162 }
3163}
3164
3165fn generate_default_impl(
3167 struct_name: &syn::Ident,
3168 backend: &str,
3169 model: Option<&str>,
3170 profile: Option<&str>,
3171 crate_path: &proc_macro2::TokenStream,
3172) -> proc_macro2::TokenStream {
3173 let profile_expr = if let Some(profile_str) = profile {
3175 match profile_str.to_lowercase().as_str() {
3176 "creative" => quote! { #crate_path::agent::ExecutionProfile::Creative },
3177 "balanced" => quote! { #crate_path::agent::ExecutionProfile::Balanced },
3178 "deterministic" => quote! { #crate_path::agent::ExecutionProfile::Deterministic },
3179 _ => quote! { #crate_path::agent::ExecutionProfile::Balanced }, }
3181 } else {
3182 quote! { #crate_path::agent::ExecutionProfile::default() }
3183 };
3184
3185 let agent_init = match backend {
3186 "gemini" => {
3187 let mut builder = quote! { #crate_path::agent::impls::GeminiAgent::new() };
3188
3189 if let Some(model_str) = model {
3190 builder = quote! { #builder.with_model_str(#model_str) };
3191 }
3192
3193 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3194 builder
3195 }
3196 _ => {
3197 let mut builder = quote! { #crate_path::agent::impls::ClaudeCodeAgent::new() };
3199
3200 if let Some(model_str) = model {
3201 builder = quote! { #builder.with_model_str(#model_str) };
3202 }
3203
3204 builder = quote! { #builder.with_execution_profile(#profile_expr) };
3205 builder
3206 }
3207 };
3208
3209 quote! {
3210 impl Default for #struct_name {
3211 fn default() -> Self {
3212 Self::new(#agent_init)
3213 }
3214 }
3215 }
3216}
3217
3218#[proc_macro_derive(Agent, attributes(agent))]
3227pub fn derive_agent(input: TokenStream) -> TokenStream {
3228 let input = parse_macro_input!(input as DeriveInput);
3229 let struct_name = &input.ident;
3230
3231 let agent_attrs = match parse_agent_attrs(&input.attrs) {
3233 Ok(attrs) => attrs,
3234 Err(e) => return e.to_compile_error().into(),
3235 };
3236
3237 let expertise = agent_attrs
3238 .expertise
3239 .unwrap_or_else(|| String::from("general AI assistant"));
3240 let output_type = agent_attrs
3241 .output
3242 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3243 let backend = agent_attrs
3244 .backend
3245 .unwrap_or_else(|| String::from("claude"));
3246 let model = agent_attrs.model;
3247 let _profile = agent_attrs.profile; let max_retries = agent_attrs.max_retries.unwrap_or(3); let found_crate =
3252 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3253 let crate_path = match found_crate {
3254 FoundCrate::Itself => {
3255 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3257 quote!(::#ident)
3258 }
3259 FoundCrate::Name(name) => {
3260 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3261 quote!(::#ident)
3262 }
3263 };
3264
3265 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
3266
3267 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3269 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3270
3271 let enhanced_expertise = if is_string_output {
3273 quote! { #expertise }
3275 } else {
3276 let type_name = quote!(#output_type).to_string();
3278 quote! {
3279 {
3280 use std::sync::OnceLock;
3281 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3282
3283 EXPERTISE_CACHE.get_or_init(|| {
3284 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3286
3287 if schema.is_empty() {
3288 format!(
3290 concat!(
3291 #expertise,
3292 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3293 "Do not include any text outside the JSON object."
3294 ),
3295 #type_name
3296 )
3297 } else {
3298 format!(
3300 concat!(
3301 #expertise,
3302 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3303 ),
3304 schema
3305 )
3306 }
3307 }).as_str()
3308 }
3309 }
3310 };
3311
3312 let agent_init = match backend.as_str() {
3314 "gemini" => {
3315 if let Some(model_str) = model {
3316 quote! {
3317 use #crate_path::agent::impls::GeminiAgent;
3318 let agent = GeminiAgent::new().with_model_str(#model_str);
3319 }
3320 } else {
3321 quote! {
3322 use #crate_path::agent::impls::GeminiAgent;
3323 let agent = GeminiAgent::new();
3324 }
3325 }
3326 }
3327 "claude" => {
3328 if let Some(model_str) = model {
3329 quote! {
3330 use #crate_path::agent::impls::ClaudeCodeAgent;
3331 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3332 }
3333 } else {
3334 quote! {
3335 use #crate_path::agent::impls::ClaudeCodeAgent;
3336 let agent = ClaudeCodeAgent::new();
3337 }
3338 }
3339 }
3340 _ => {
3341 if let Some(model_str) = model {
3343 quote! {
3344 use #crate_path::agent::impls::ClaudeCodeAgent;
3345 let agent = ClaudeCodeAgent::new().with_model_str(#model_str);
3346 }
3347 } else {
3348 quote! {
3349 use #crate_path::agent::impls::ClaudeCodeAgent;
3350 let agent = ClaudeCodeAgent::new();
3351 }
3352 }
3353 }
3354 };
3355
3356 let expanded = quote! {
3357 #[async_trait::async_trait]
3358 impl #impl_generics #crate_path::agent::Agent for #struct_name #ty_generics #where_clause {
3359 type Output = #output_type;
3360
3361 fn expertise(&self) -> &str {
3362 #enhanced_expertise
3363 }
3364
3365 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3366 #agent_init
3368
3369 let agent_ref = &agent;
3371 #crate_path::agent::retry::retry_execution(
3372 #max_retries,
3373 &intent,
3374 move |payload| {
3375 let payload = payload.clone();
3376 async move {
3377 let response = agent_ref.execute(payload).await?;
3379
3380 let json_str = #crate_path::extract_json(&response)
3382 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3383 message: format!("Failed to extract JSON: {}", e),
3384 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3385 })?;
3386
3387 serde_json::from_str::<Self::Output>(&json_str)
3389 .map_err(|e| {
3390 let reason = if e.is_eof() {
3392 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3393 } else if e.is_syntax() {
3394 #crate_path::agent::error::ParseErrorReason::InvalidJson
3395 } else {
3396 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3397 };
3398
3399 #crate_path::agent::AgentError::ParseError {
3400 message: format!("Failed to parse JSON: {}", e),
3401 reason,
3402 }
3403 })
3404 }
3405 }
3406 ).await
3407 }
3408
3409 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3410 #agent_init
3412 agent.is_available().await
3413 }
3414 }
3415 };
3416
3417 TokenStream::from(expanded)
3418}
3419
3420#[proc_macro_attribute]
3435pub fn agent(attr: TokenStream, item: TokenStream) -> TokenStream {
3436 let agent_attrs = match syn::parse::<AgentAttrs>(attr) {
3438 Ok(attrs) => attrs,
3439 Err(e) => return e.to_compile_error().into(),
3440 };
3441
3442 let input = parse_macro_input!(item as DeriveInput);
3444 let struct_name = &input.ident;
3445 let vis = &input.vis;
3446
3447 let expertise = agent_attrs
3448 .expertise
3449 .unwrap_or_else(|| String::from("general AI assistant"));
3450 let output_type = agent_attrs
3451 .output
3452 .unwrap_or_else(|| syn::parse_str::<syn::Type>("String").unwrap());
3453 let backend = agent_attrs
3454 .backend
3455 .unwrap_or_else(|| String::from("claude"));
3456 let model = agent_attrs.model;
3457 let profile = agent_attrs.profile;
3458
3459 let output_type_str = quote!(#output_type).to_string().replace(" ", "");
3461 let is_string_output = output_type_str == "String" || output_type_str == "&str";
3462
3463 let found_crate =
3465 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3466 let crate_path = match found_crate {
3467 FoundCrate::Itself => {
3468 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3469 quote!(::#ident)
3470 }
3471 FoundCrate::Name(name) => {
3472 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3473 quote!(::#ident)
3474 }
3475 };
3476
3477 let inner_generic_name = agent_attrs.inner.unwrap_or_else(|| String::from("A"));
3479 let inner_generic_ident = syn::Ident::new(&inner_generic_name, proc_macro2::Span::call_site());
3480
3481 let default_agent_type = if let Some(ref custom_type) = agent_attrs.default_inner {
3483 let type_path: syn::Type =
3485 syn::parse_str(custom_type).expect("default_inner must be a valid type path");
3486 quote! { #type_path }
3487 } else {
3488 match backend.as_str() {
3490 "gemini" => quote! { #crate_path::agent::impls::GeminiAgent },
3491 _ => quote! { #crate_path::agent::impls::ClaudeCodeAgent },
3492 }
3493 };
3494
3495 let struct_def = quote! {
3497 #vis struct #struct_name<#inner_generic_ident = #default_agent_type> {
3498 inner: #inner_generic_ident,
3499 }
3500 };
3501
3502 let constructors = quote! {
3504 impl<#inner_generic_ident> #struct_name<#inner_generic_ident> {
3505 pub fn new(inner: #inner_generic_ident) -> Self {
3507 Self { inner }
3508 }
3509 }
3510 };
3511
3512 let (backend_constructors, default_impl) = if agent_attrs.default_inner.is_some() {
3514 let default_impl = quote! {
3516 impl Default for #struct_name {
3517 fn default() -> Self {
3518 Self {
3519 inner: <#default_agent_type as Default>::default(),
3520 }
3521 }
3522 }
3523 };
3524 (quote! {}, default_impl)
3525 } else {
3526 let backend_constructors = generate_backend_constructors(
3528 struct_name,
3529 &backend,
3530 model.as_deref(),
3531 profile.as_deref(),
3532 &crate_path,
3533 );
3534 let default_impl = generate_default_impl(
3535 struct_name,
3536 &backend,
3537 model.as_deref(),
3538 profile.as_deref(),
3539 &crate_path,
3540 );
3541 (backend_constructors, default_impl)
3542 };
3543
3544 let enhanced_expertise = if is_string_output {
3546 quote! { #expertise }
3548 } else {
3549 let type_name = quote!(#output_type).to_string();
3551 quote! {
3552 {
3553 use std::sync::OnceLock;
3554 static EXPERTISE_CACHE: OnceLock<String> = OnceLock::new();
3555
3556 EXPERTISE_CACHE.get_or_init(|| {
3557 let schema = <#output_type as #crate_path::prompt::ToPrompt>::prompt_schema();
3559
3560 if schema.is_empty() {
3561 format!(
3563 concat!(
3564 #expertise,
3565 "\n\nIMPORTANT: You must respond with valid JSON matching the {} type structure. ",
3566 "Do not include any text outside the JSON object."
3567 ),
3568 #type_name
3569 )
3570 } else {
3571 format!(
3573 concat!(
3574 #expertise,
3575 "\n\nIMPORTANT: Respond with valid JSON matching this schema:\n\n{}"
3576 ),
3577 schema
3578 )
3579 }
3580 }).as_str()
3581 }
3582 }
3583 };
3584
3585 let agent_impl = quote! {
3587 #[async_trait::async_trait]
3588 impl<#inner_generic_ident> #crate_path::agent::Agent for #struct_name<#inner_generic_ident>
3589 where
3590 #inner_generic_ident: #crate_path::agent::Agent<Output = String>,
3591 {
3592 type Output = #output_type;
3593
3594 fn expertise(&self) -> &str {
3595 #enhanced_expertise
3596 }
3597
3598 async fn execute(&self, intent: #crate_path::agent::Payload) -> Result<Self::Output, #crate_path::agent::AgentError> {
3599 let enhanced_payload = intent.prepend_text(self.expertise());
3601
3602 let response = self.inner.execute(enhanced_payload).await?;
3604
3605 let json_str = #crate_path::extract_json(&response)
3607 .map_err(|e| #crate_path::agent::AgentError::ParseError {
3608 message: e.to_string(),
3609 reason: #crate_path::agent::error::ParseErrorReason::MarkdownExtractionFailed,
3610 })?;
3611
3612 serde_json::from_str(&json_str).map_err(|e| {
3614 let reason = if e.is_eof() {
3615 #crate_path::agent::error::ParseErrorReason::UnexpectedEof
3616 } else if e.is_syntax() {
3617 #crate_path::agent::error::ParseErrorReason::InvalidJson
3618 } else {
3619 #crate_path::agent::error::ParseErrorReason::SchemaMismatch
3620 };
3621 #crate_path::agent::AgentError::ParseError {
3622 message: e.to_string(),
3623 reason,
3624 }
3625 })
3626 }
3627
3628 async fn is_available(&self) -> Result<(), #crate_path::agent::AgentError> {
3629 self.inner.is_available().await
3630 }
3631 }
3632 };
3633
3634 let expanded = quote! {
3635 #struct_def
3636 #constructors
3637 #backend_constructors
3638 #default_impl
3639 #agent_impl
3640 };
3641
3642 TokenStream::from(expanded)
3643}
3644
3645#[proc_macro_derive(TypeMarker)]
3667pub fn derive_type_marker(input: TokenStream) -> TokenStream {
3668 let input = parse_macro_input!(input as DeriveInput);
3669 let struct_name = &input.ident;
3670 let type_name_str = struct_name.to_string();
3671
3672 let found_crate =
3674 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3675 let crate_path = match found_crate {
3676 FoundCrate::Itself => {
3677 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3678 quote!(::#ident)
3679 }
3680 FoundCrate::Name(name) => {
3681 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3682 quote!(::#ident)
3683 }
3684 };
3685
3686 let expanded = quote! {
3687 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3688 const TYPE_NAME: &'static str = #type_name_str;
3689 }
3690 };
3691
3692 TokenStream::from(expanded)
3693}
3694
3695#[proc_macro_attribute]
3731pub fn type_marker(_attr: TokenStream, item: TokenStream) -> TokenStream {
3732 let input = parse_macro_input!(item as syn::DeriveInput);
3733 let struct_name = &input.ident;
3734 let vis = &input.vis;
3735 let type_name_str = struct_name.to_string();
3736
3737 let default_fn_name = syn::Ident::new(
3739 &format!("default_{}_type", to_snake_case(&type_name_str)),
3740 struct_name.span(),
3741 );
3742
3743 let found_crate =
3745 crate_name("llm-toolkit").expect("llm-toolkit should be present in `Cargo.toml`");
3746 let crate_path = match found_crate {
3747 FoundCrate::Itself => {
3748 let ident = syn::Ident::new("llm_toolkit", proc_macro2::Span::call_site());
3749 quote!(::#ident)
3750 }
3751 FoundCrate::Name(name) => {
3752 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
3753 quote!(::#ident)
3754 }
3755 };
3756
3757 let fields = match &input.data {
3759 syn::Data::Struct(data_struct) => match &data_struct.fields {
3760 syn::Fields::Named(fields) => &fields.named,
3761 _ => {
3762 return syn::Error::new_spanned(
3763 struct_name,
3764 "type_marker only works with structs with named fields",
3765 )
3766 .to_compile_error()
3767 .into();
3768 }
3769 },
3770 _ => {
3771 return syn::Error::new_spanned(struct_name, "type_marker only works with structs")
3772 .to_compile_error()
3773 .into();
3774 }
3775 };
3776
3777 let mut new_fields = vec![];
3779
3780 let default_fn_name_str = default_fn_name.to_string();
3782 let default_fn_name_lit = syn::LitStr::new(&default_fn_name_str, default_fn_name.span());
3783
3784 new_fields.push(quote! {
3789 #[serde(default = #default_fn_name_lit)]
3790 __type: String
3791 });
3792
3793 for field in fields {
3795 new_fields.push(quote! { #field });
3796 }
3797
3798 let attrs = &input.attrs;
3800 let generics = &input.generics;
3801
3802 let expanded = quote! {
3803 fn #default_fn_name() -> String {
3805 #type_name_str.to_string()
3806 }
3807
3808 #(#attrs)*
3810 #vis struct #struct_name #generics {
3811 #(#new_fields),*
3812 }
3813
3814 impl #crate_path::orchestrator::TypeMarker for #struct_name {
3816 const TYPE_NAME: &'static str = #type_name_str;
3817 }
3818 };
3819
3820 TokenStream::from(expanded)
3821}