1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 Data, DeriveInput, Meta, Token,
5 parse::{Parse, ParseStream},
6 parse_macro_input,
7 punctuated::Punctuated,
8};
9
/// Converts a template written with single-brace placeholders (`{field}`)
/// into MiniJinja syntax (`{{field}}`).
///
/// Text that is already MiniJinja syntax is preserved:
/// - `{{` and `}}` are copied through, so `{{ field }}` keeps working;
/// - `{% ... %}` statement blocks and `{# ... #}` comments are copied verbatim
///   up to and including their closing delimiter (or to the end of the input
///   when unterminated). Rewriting them brace-by-brace would turn `{%` into
///   `{{%`, which MiniJinja cannot parse.
///
/// Every other lone `{` becomes `{{` and every other lone `}` becomes `}}`.
fn convert_to_minijinja_syntax(template: &str) -> String {
    let mut result = String::with_capacity(template.len());
    let mut chars = template.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '{' => match chars.peek().copied() {
                // `{{` already opens a MiniJinja expression: keep it as-is.
                Some('{') => {
                    chars.next();
                    result.push_str("{{");
                }
                // `{%` / `{#`: copy the statement or comment through its
                // matching `%}` / `#}` without touching inner braces.
                Some(marker @ ('%' | '#')) => {
                    chars.next();
                    result.push('{');
                    result.push(marker);
                    while let Some(inner) = chars.next() {
                        result.push(inner);
                        if inner == marker && chars.peek() == Some(&'}') {
                            chars.next();
                            result.push('}');
                            break;
                        }
                    }
                }
                // A lone `{` opens a single-brace placeholder.
                _ => result.push_str("{{"),
            },
            '}' => {
                // `}}` is kept and a lone `}` is doubled; both emit `}}`, but
                // only the `}}` form consumes the second character.
                if chars.peek() == Some(&'}') {
                    chars.next();
                }
                result.push_str("}}");
            }
            other => result.push(other),
        }
    }

    result
}
42
/// Extracts the single-brace placeholders from a template.
///
/// `{name}` yields `("name", None)` and `{name:mode}` yields
/// `("name", Some("mode"))`; the name is split at the first `:` and both
/// parts are trimmed. An unterminated `{name` at the end is still reported.
///
/// The result lists each distinct `(name, mode)` pair once, in order of first
/// appearance. Callers rewrite every occurrence of a placeholder with one
/// `str::replace`; handing them a duplicate makes the second replace match
/// inside the already-converted `{{name}}` and corrupt it into `{{{name}}}`.
///
/// Not reported as placeholders:
/// - `{{ ... }}` MiniJinja expressions (the leading `{{` is skipped);
/// - `{% ... %}` statement blocks and `{# ... #}` comments, skipped through
///   their closing delimiter so their contents are never taken for field names.
fn parse_template_placeholders(template: &str) -> Vec<(String, Option<String>)> {
    let mut placeholders: Vec<(String, Option<String>)> = Vec::new();
    let mut chars = template.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            continue;
        }

        match chars.peek().copied() {
            // `{{`: a native MiniJinja expression, not a placeholder.
            Some('{') => {
                chars.next();
                continue;
            }
            // `{%` / `{#`: skip everything up to the matching `%}` / `#}`.
            Some(marker @ ('%' | '#')) => {
                chars.next();
                while let Some(inner) = chars.next() {
                    if inner == marker && chars.peek() == Some(&'}') {
                        chars.next();
                        break;
                    }
                }
                continue;
            }
            _ => {}
        }

        let mut placeholder = String::new();
        for inner in chars.by_ref() {
            if inner == '}' {
                break;
            }
            placeholder.push(inner);
        }

        let entry = match placeholder.split_once(':') {
            Some((field_name, mode)) => (field_name.trim().to_string(), Some(mode.trim().to_string())),
            None => (placeholder.trim().to_string(), None),
        };

        if !placeholders.contains(&entry) {
            placeholders.push(entry);
        }
    }

    placeholders
}
79
80fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
82 attrs
83 .iter()
84 .filter_map(|attr| {
85 if attr.path().is_ident("doc")
86 && let syn::Meta::NameValue(meta_name_value) = &attr.meta
87 && let syn::Expr::Lit(syn::ExprLit {
88 lit: syn::Lit::Str(lit_str),
89 ..
90 }) = &meta_name_value.value
91 {
92 return Some(lit_str.value());
93 }
94 None
95 })
96 .map(|s| s.trim().to_string())
97 .collect::<Vec<_>>()
98 .join(" ")
99}
100
101fn generate_example_only_parts(
103 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
104 has_default: bool,
105) -> proc_macro2::TokenStream {
106 let mut field_values = Vec::new();
107
108 for field in fields.iter() {
109 let field_name = field.ident.as_ref().unwrap();
110 let field_name_str = field_name.to_string();
111 let attrs = parse_field_prompt_attrs(&field.attrs);
112
113 if attrs.skip {
115 continue;
116 }
117
118 if let Some(example) = attrs.example {
120 field_values.push(quote! {
122 json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
123 });
124 } else if has_default {
125 field_values.push(quote! {
127 let default_value = serde_json::to_value(&default_instance.#field_name)
128 .unwrap_or(serde_json::Value::Null);
129 json_obj.insert(#field_name_str.to_string(), default_value);
130 });
131 } else {
132 field_values.push(quote! {
134 let value = serde_json::to_value(&self.#field_name)
135 .unwrap_or(serde_json::Value::Null);
136 json_obj.insert(#field_name_str.to_string(), value);
137 });
138 }
139 }
140
141 if has_default {
142 quote! {
143 {
144 let default_instance = Self::default();
145 let mut json_obj = serde_json::Map::new();
146 #(#field_values)*
147 let json_value = serde_json::Value::Object(json_obj);
148 let json_str = serde_json::to_string_pretty(&json_value)
149 .unwrap_or_else(|_| "{}".to_string());
150 vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
151 }
152 }
153 } else {
154 quote! {
155 {
156 let mut json_obj = serde_json::Map::new();
157 #(#field_values)*
158 let json_value = serde_json::Value::Object(json_obj);
159 let json_str = serde_json::to_string_pretty(&json_value)
160 .unwrap_or_else(|_| "{}".to_string());
161 vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
162 }
163 }
164 }
165}
166
167fn generate_schema_only_parts(
169 struct_name: &str,
170 struct_docs: &str,
171 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
172) -> proc_macro2::TokenStream {
173 let mut schema_lines = vec![];
174
175 if !struct_docs.is_empty() {
177 schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
178 } else {
179 schema_lines.push(format!("### Schema for `{}`", struct_name));
180 }
181
182 schema_lines.push("{".to_string());
183
184 for (i, field) in fields.iter().enumerate() {
186 let field_name = field.ident.as_ref().unwrap();
187 let attrs = parse_field_prompt_attrs(&field.attrs);
188
189 if attrs.skip {
191 continue;
192 }
193
194 let field_docs = extract_doc_comments(&field.attrs);
196
197 let type_str = format_type_for_schema(&field.ty);
199
200 let mut field_line = format!(" \"{}\": \"{}\"", field_name, type_str);
202
203 if !field_docs.is_empty() {
205 field_line.push_str(&format!(", // {}", field_docs));
206 }
207
208 let remaining_fields = fields
210 .iter()
211 .skip(i + 1)
212 .filter(|f| {
213 let attrs = parse_field_prompt_attrs(&f.attrs);
214 !attrs.skip
215 })
216 .count();
217
218 if remaining_fields > 0 {
219 field_line.push(',');
220 }
221
222 schema_lines.push(field_line);
223 }
224
225 schema_lines.push("}".to_string());
226
227 let schema_str = schema_lines.join("\n");
228
229 quote! {
230 vec![llm_toolkit::prompt::PromptPart::Text(#schema_str.to_string())]
231 }
232}
233
234fn format_type_for_schema(ty: &syn::Type) -> String {
236 match ty {
238 syn::Type::Path(type_path) => {
239 let path = &type_path.path;
240 if let Some(last_segment) = path.segments.last() {
241 let type_name = last_segment.ident.to_string();
242
243 if type_name == "Option"
245 && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
246 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
247 {
248 return format!("{} | null", format_type_for_schema(inner_type));
249 }
250
251 match type_name.as_str() {
253 "String" | "str" => "string".to_string(),
254 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
255 | "u64" | "u128" | "usize" => "number".to_string(),
256 "f32" | "f64" => "number".to_string(),
257 "bool" => "boolean".to_string(),
258 "Vec" => {
259 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
260 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
261 {
262 return format!("{}[]", format_type_for_schema(inner_type));
263 }
264 "array".to_string()
265 }
266 _ => type_name.to_lowercase(),
267 }
268 } else {
269 "unknown".to_string()
270 }
271 }
272 _ => "unknown".to_string(),
273 }
274}
275
/// Result of reading an enum variant's `#[prompt(...)]` attribute.
enum PromptAttribute {
    /// `#[prompt(skip)]`: omit the variant from the generated prompt.
    Skip,
    /// `#[prompt("...")]`: use the given text as the variant's description.
    Description(String),
    /// No `#[prompt]` attribute of either form was found.
    None,
}
282
283fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
285 for attr in attrs {
286 if attr.path().is_ident("prompt") {
287 if let Ok(meta_list) = attr.meta.require_list() {
289 let tokens = &meta_list.tokens;
290 let tokens_str = tokens.to_string();
291 if tokens_str == "skip" {
292 return PromptAttribute::Skip;
293 }
294 }
295
296 if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
298 return PromptAttribute::Description(lit_str.value());
299 }
300 }
301 }
302 PromptAttribute::None
303}
304
/// Options gathered from a struct field's `#[prompt(...)]` attributes.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// `skip`: leave the field out of the generated prompt.
    skip: bool,
    /// `rename = "..."`: key used instead of the doc comment or field name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function that formats the field value.
    format_with: Option<String>,
    /// `image`: the field contributes its own prompt parts, not a text line.
    image: bool,
    /// `example = "..."`: literal value used in `example_only` output.
    example: Option<String>,
}
314
315fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
317 let mut result = FieldPromptAttrs::default();
318
319 for attr in attrs {
320 if attr.path().is_ident("prompt") {
321 if let Ok(meta_list) = attr.meta.require_list() {
323 if let Ok(metas) =
325 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
326 {
327 for meta in metas {
328 match meta {
329 Meta::Path(path) if path.is_ident("skip") => {
330 result.skip = true;
331 }
332 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
333 if let syn::Expr::Lit(syn::ExprLit {
334 lit: syn::Lit::Str(lit_str),
335 ..
336 }) = nv.value
337 {
338 result.rename = Some(lit_str.value());
339 }
340 }
341 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
342 if let syn::Expr::Lit(syn::ExprLit {
343 lit: syn::Lit::Str(lit_str),
344 ..
345 }) = nv.value
346 {
347 result.format_with = Some(lit_str.value());
348 }
349 }
350 Meta::Path(path) if path.is_ident("image") => {
351 result.image = true;
352 }
353 Meta::NameValue(nv) if nv.path.is_ident("example") => {
354 if let syn::Expr::Lit(syn::ExprLit {
355 lit: syn::Lit::Str(lit_str),
356 ..
357 }) = nv.value
358 {
359 result.example = Some(lit_str.value());
360 }
361 }
362 _ => {}
363 }
364 }
365 } else if meta_list.tokens.to_string() == "skip" {
366 result.skip = true;
368 } else if meta_list.tokens.to_string() == "image" {
369 result.image = true;
371 }
372 }
373 }
374 }
375
376 result
377}
378
379#[proc_macro_derive(ToPrompt, attributes(prompt))]
422pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
423 let input = parse_macro_input!(input as DeriveInput);
424
425 match &input.data {
427 Data::Enum(data_enum) => {
428 let enum_name = &input.ident;
430 let enum_docs = extract_doc_comments(&input.attrs);
431
432 let mut prompt_lines = Vec::new();
433
434 if !enum_docs.is_empty() {
436 prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
437 } else {
438 prompt_lines.push(format!("{}:", enum_name));
439 }
440 prompt_lines.push(String::new()); prompt_lines.push("Possible values:".to_string());
442
443 for variant in &data_enum.variants {
445 let variant_name = &variant.ident;
446
447 match parse_prompt_attribute(&variant.attrs) {
449 PromptAttribute::Skip => {
450 continue;
452 }
453 PromptAttribute::Description(desc) => {
454 prompt_lines.push(format!("- {}: {}", variant_name, desc));
456 }
457 PromptAttribute::None => {
458 let variant_docs = extract_doc_comments(&variant.attrs);
460 if !variant_docs.is_empty() {
461 prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
462 } else {
463 prompt_lines.push(format!("- {}", variant_name));
464 }
465 }
466 }
467 }
468
469 let prompt_string = prompt_lines.join("\n");
470 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
471
472 let expanded = quote! {
473 impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
474 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
475 vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
476 }
477
478 fn to_prompt(&self) -> String {
479 #prompt_string.to_string()
480 }
481 }
482 };
483
484 TokenStream::from(expanded)
485 }
486 Data::Struct(data_struct) => {
487 let mut template_attr = None;
489 let mut mode_attr = None;
490
491 for attr in &input.attrs {
492 if attr.path().is_ident("prompt") {
493 if let Ok(metas) =
495 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
496 {
497 for meta in metas {
498 match meta {
499 Meta::NameValue(nv) if nv.path.is_ident("template") => {
500 if let syn::Expr::Lit(expr_lit) = nv.value
501 && let syn::Lit::Str(lit_str) = expr_lit.lit
502 {
503 template_attr = Some(lit_str.value());
504 }
505 }
506 Meta::NameValue(nv) if nv.path.is_ident("mode") => {
507 if let syn::Expr::Lit(expr_lit) = nv.value
508 && let syn::Lit::Str(lit_str) = expr_lit.lit
509 {
510 mode_attr = Some(lit_str.value());
511 }
512 }
513 _ => {}
514 }
515 }
516 }
517 }
518 }
519
520 let name = input.ident;
521 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
522
523 let struct_docs = extract_doc_comments(&input.attrs);
525
526 let is_mode_based =
528 mode_attr.is_some() || (template_attr.is_none() && struct_docs.contains("mode"));
529
530 let expanded = if is_mode_based || mode_attr.is_some() {
531 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
533 &fields.named
534 } else {
535 panic!(
536 "Mode-based prompt generation is only supported for structs with named fields."
537 );
538 };
539
540 let struct_name_str = name.to_string();
541
542 let has_default = input.attrs.iter().any(|attr| {
544 if attr.path().is_ident("derive") {
545 if let Ok(meta_list) = attr.meta.require_list() {
546 let tokens_str = meta_list.tokens.to_string();
547 tokens_str.contains("Default")
548 } else {
549 false
550 }
551 } else {
552 false
553 }
554 });
555
556 let schema_parts =
558 generate_schema_only_parts(&struct_name_str, &struct_docs, fields);
559
560 let example_parts = generate_example_only_parts(fields, has_default);
562
563 quote! {
564 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
565 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<llm_toolkit::prompt::PromptPart> {
566 match mode {
567 "schema_only" => #schema_parts,
568 "example_only" => #example_parts,
569 "full" | _ => {
570 let mut parts = Vec::new();
572
573 let schema_parts = #schema_parts;
575 parts.extend(schema_parts);
576
577 parts.push(llm_toolkit::prompt::PromptPart::Text("\n### Example".to_string()));
579 parts.push(llm_toolkit::prompt::PromptPart::Text(
580 format!("Here is an example of a valid `{}` object:", #struct_name_str)
581 ));
582
583 let example_parts = #example_parts;
585 parts.extend(example_parts);
586
587 parts
588 }
589 }
590 }
591
592 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
593 self.to_prompt_parts_with_mode("full")
594 }
595
596 fn to_prompt(&self) -> String {
597 self.to_prompt_parts()
598 .into_iter()
599 .filter_map(|part| match part {
600 llm_toolkit::prompt::PromptPart::Text(text) => Some(text),
601 _ => None,
602 })
603 .collect::<Vec<_>>()
604 .join("\n")
605 }
606 }
607 }
608 } else if let Some(template_str) = template_attr {
609 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
612 &fields.named
613 } else {
614 panic!(
615 "Template prompt generation is only supported for structs with named fields."
616 );
617 };
618
619 let placeholders = parse_template_placeholders(&template_str);
621 let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
623 mode.is_some()
624 && fields
625 .iter()
626 .any(|f| f.ident.as_ref().unwrap() == field_name)
627 });
628
629 let mut image_field_parts = Vec::new();
630 for f in fields.iter() {
631 let field_name = f.ident.as_ref().unwrap();
632 let attrs = parse_field_prompt_attrs(&f.attrs);
633
634 if attrs.image {
635 image_field_parts.push(quote! {
637 parts.extend(self.#field_name.to_prompt_parts());
638 });
639 }
640 }
641
642 if has_mode_syntax {
644 let mut context_fields = Vec::new();
646
647 let mut converted_template = template_str.clone();
650
651 for (field_name, mode_opt) in &placeholders {
653 let field_ident =
655 syn::Ident::new(field_name, proc_macro2::Span::call_site());
656
657 if let Some(mode) = mode_opt {
658 let unique_key = format!("{}__{}", field_name, mode);
660
661 let pattern = format!("{{{}:{}}}", field_name, mode);
663 let replacement = format!("{{{{{}}}}}", unique_key);
664 converted_template = converted_template.replace(&pattern, &replacement);
665
666 context_fields.push(quote! {
668 context.insert(
669 #unique_key.to_string(),
670 minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
671 );
672 });
673 } else {
674 let pattern = format!("{{{}}}", field_name);
676 let replacement = format!("{{{{{}}}}}", field_name);
677 converted_template = converted_template.replace(&pattern, &replacement);
678
679 context_fields.push(quote! {
681 context.insert(
682 #field_name.to_string(),
683 minijinja::Value::from(self.#field_ident.to_prompt())
684 );
685 });
686 }
687 }
688
689 quote! {
690 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
691 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
692 let mut parts = Vec::new();
693
694 #(#image_field_parts)*
696
697 let text = {
699 let mut env = minijinja::Environment::new();
700 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
701 panic!("Failed to parse template: {}", e)
702 });
703
704 let tmpl = env.get_template("prompt").unwrap();
705
706 let mut context = std::collections::HashMap::new();
707 #(#context_fields)*
708
709 tmpl.render(context).unwrap_or_else(|e| {
710 format!("Failed to render prompt: {}", e)
711 })
712 };
713
714 if !text.is_empty() {
715 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
716 }
717
718 parts
719 }
720
721 fn to_prompt(&self) -> String {
722 let mut env = minijinja::Environment::new();
724 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
725 panic!("Failed to parse template: {}", e)
726 });
727
728 let tmpl = env.get_template("prompt").unwrap();
729
730 let mut context = std::collections::HashMap::new();
731 #(#context_fields)*
732
733 tmpl.render(context).unwrap_or_else(|e| {
734 format!("Failed to render prompt: {}", e)
735 })
736 }
737 }
738 }
739 } else {
740 let converted_template = convert_to_minijinja_syntax(&template_str);
742
743 quote! {
744 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
745 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
746 let mut parts = Vec::new();
747
748 #(#image_field_parts)*
750
751 let text = llm_toolkit::prompt::render_prompt(#converted_template, self).unwrap_or_else(|e| {
753 format!("Failed to render prompt: {}", e)
754 });
755 if !text.is_empty() {
756 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
757 }
758
759 parts
760 }
761
762 fn to_prompt(&self) -> String {
763 llm_toolkit::prompt::render_prompt(#converted_template, self).unwrap_or_else(|e| {
764 format!("Failed to render prompt: {}", e)
765 })
766 }
767 }
768 }
769 }
770 } else {
771 let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
774 &fields.named
775 } else {
776 panic!(
777 "Default prompt generation is only supported for structs with named fields."
778 );
779 };
780
781 let mut text_field_parts = Vec::new();
783 let mut image_field_parts = Vec::new();
784
785 for f in fields.iter() {
786 let field_name = f.ident.as_ref().unwrap();
787 let attrs = parse_field_prompt_attrs(&f.attrs);
788
789 if attrs.skip {
791 continue;
792 }
793
794 if attrs.image {
795 image_field_parts.push(quote! {
797 parts.extend(self.#field_name.to_prompt_parts());
798 });
799 } else {
800 let key = if let Some(rename) = attrs.rename {
806 rename
807 } else {
808 let doc_comment = extract_doc_comments(&f.attrs);
809 if !doc_comment.is_empty() {
810 doc_comment
811 } else {
812 field_name.to_string()
813 }
814 };
815
816 let value_expr = if let Some(format_with) = attrs.format_with {
818 let func_path: syn::Path =
820 syn::parse_str(&format_with).unwrap_or_else(|_| {
821 panic!("Invalid function path: {}", format_with)
822 });
823 quote! { #func_path(&self.#field_name) }
824 } else {
825 quote! { self.#field_name.to_prompt() }
826 };
827
828 text_field_parts.push(quote! {
829 text_parts.push(format!("{}: {}", #key, #value_expr));
830 });
831 }
832 }
833
834 quote! {
836 impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
837 fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
838 let mut parts = Vec::new();
839
840 #(#image_field_parts)*
842
843 let mut text_parts = Vec::new();
845 #(#text_field_parts)*
846
847 if !text_parts.is_empty() {
848 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
849 }
850
851 parts
852 }
853
854 fn to_prompt(&self) -> String {
855 let mut text_parts = Vec::new();
856 #(#text_field_parts)*
857 text_parts.join("\n")
858 }
859 }
860 }
861 };
862
863 TokenStream::from(expanded)
864 }
865 Data::Union(_) => {
866 panic!("`#[derive(ToPrompt)]` is not supported for unions");
867 }
868 }
869}
870
871#[derive(Debug, Clone)]
873struct TargetInfo {
874 name: String,
875 template: Option<String>,
876 field_configs: std::collections::HashMap<String, FieldTargetConfig>,
877}
878
/// Settings parsed from one field-level `#[prompt_for(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: leave the field out.
    skip: bool,
    /// `rename = "..."`: key used instead of the field name.
    rename: Option<String>,
    /// `format_with = "path::to::fn"`: function that formats the value.
    format_with: Option<String>,
    /// `image`: the field contributes its own prompt parts.
    image: bool,
    /// Set when the attribute named a specific target via `name = "..."`.
    include_only: bool,
}
888
889fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
891 let mut configs = Vec::new();
892
893 for attr in attrs {
894 if attr.path().is_ident("prompt_for")
895 && let Ok(meta_list) = attr.meta.require_list()
896 {
897 if meta_list.tokens.to_string() == "skip" {
899 let config = FieldTargetConfig {
901 skip: true,
902 ..Default::default()
903 };
904 configs.push(("*".to_string(), config));
905 } else if let Ok(metas) =
906 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
907 {
908 let mut target_name = None;
909 let mut config = FieldTargetConfig::default();
910
911 for meta in metas {
912 match meta {
913 Meta::NameValue(nv) if nv.path.is_ident("name") => {
914 if let syn::Expr::Lit(syn::ExprLit {
915 lit: syn::Lit::Str(lit_str),
916 ..
917 }) = nv.value
918 {
919 target_name = Some(lit_str.value());
920 }
921 }
922 Meta::Path(path) if path.is_ident("skip") => {
923 config.skip = true;
924 }
925 Meta::NameValue(nv) if nv.path.is_ident("rename") => {
926 if let syn::Expr::Lit(syn::ExprLit {
927 lit: syn::Lit::Str(lit_str),
928 ..
929 }) = nv.value
930 {
931 config.rename = Some(lit_str.value());
932 }
933 }
934 Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
935 if let syn::Expr::Lit(syn::ExprLit {
936 lit: syn::Lit::Str(lit_str),
937 ..
938 }) = nv.value
939 {
940 config.format_with = Some(lit_str.value());
941 }
942 }
943 Meta::Path(path) if path.is_ident("image") => {
944 config.image = true;
945 }
946 _ => {}
947 }
948 }
949
950 if let Some(name) = target_name {
951 config.include_only = true;
952 configs.push((name, config));
953 }
954 }
955 }
956 }
957
958 configs
959}
960
961fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
963 let mut targets = Vec::new();
964
965 for attr in attrs {
966 if attr.path().is_ident("prompt_for")
967 && let Ok(meta_list) = attr.meta.require_list()
968 && let Ok(metas) =
969 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
970 {
971 let mut target_name = None;
972 let mut template = None;
973
974 for meta in metas {
975 match meta {
976 Meta::NameValue(nv) if nv.path.is_ident("name") => {
977 if let syn::Expr::Lit(syn::ExprLit {
978 lit: syn::Lit::Str(lit_str),
979 ..
980 }) = nv.value
981 {
982 target_name = Some(lit_str.value());
983 }
984 }
985 Meta::NameValue(nv) if nv.path.is_ident("template") => {
986 if let syn::Expr::Lit(syn::ExprLit {
987 lit: syn::Lit::Str(lit_str),
988 ..
989 }) = nv.value
990 {
991 template = Some(lit_str.value());
992 }
993 }
994 _ => {}
995 }
996 }
997
998 if let Some(name) = target_name {
999 targets.push(TargetInfo {
1000 name,
1001 template,
1002 field_configs: std::collections::HashMap::new(),
1003 });
1004 }
1005 }
1006 }
1007
1008 targets
1009}
1010
1011#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1012pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1013 let input = parse_macro_input!(input as DeriveInput);
1014
1015 let data_struct = match &input.data {
1017 Data::Struct(data) => data,
1018 _ => {
1019 return syn::Error::new(
1020 input.ident.span(),
1021 "`#[derive(ToPromptSet)]` is only supported for structs",
1022 )
1023 .to_compile_error()
1024 .into();
1025 }
1026 };
1027
1028 let fields = match &data_struct.fields {
1029 syn::Fields::Named(fields) => &fields.named,
1030 _ => {
1031 return syn::Error::new(
1032 input.ident.span(),
1033 "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1034 )
1035 .to_compile_error()
1036 .into();
1037 }
1038 };
1039
1040 let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1042
1043 for field in fields.iter() {
1045 let field_name = field.ident.as_ref().unwrap().to_string();
1046 let field_configs = parse_prompt_for_attrs(&field.attrs);
1047
1048 for (target_name, config) in field_configs {
1049 if target_name == "*" {
1050 for target in &mut targets {
1052 target
1053 .field_configs
1054 .entry(field_name.clone())
1055 .or_insert_with(FieldTargetConfig::default)
1056 .skip = config.skip;
1057 }
1058 } else {
1059 let target_exists = targets.iter().any(|t| t.name == target_name);
1061 if !target_exists {
1062 targets.push(TargetInfo {
1064 name: target_name.clone(),
1065 template: None,
1066 field_configs: std::collections::HashMap::new(),
1067 });
1068 }
1069
1070 let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1071
1072 target.field_configs.insert(field_name.clone(), config);
1073 }
1074 }
1075 }
1076
1077 let mut match_arms = Vec::new();
1079
1080 for target in &targets {
1081 let target_name = &target.name;
1082
1083 if let Some(template_str) = &target.template {
1084 let mut image_parts = Vec::new();
1086
1087 for field in fields.iter() {
1088 let field_name = field.ident.as_ref().unwrap();
1089 let field_name_str = field_name.to_string();
1090
1091 if let Some(config) = target.field_configs.get(&field_name_str)
1092 && config.image
1093 {
1094 image_parts.push(quote! {
1095 parts.extend(self.#field_name.to_prompt_parts());
1096 });
1097 }
1098 }
1099
1100 match_arms.push(quote! {
1101 #target_name => {
1102 let mut parts = Vec::new();
1103
1104 #(#image_parts)*
1105
1106 let text = llm_toolkit::prompt::render_prompt(#template_str, self)
1107 .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
1108 target: #target_name.to_string(),
1109 source: e,
1110 })?;
1111
1112 if !text.is_empty() {
1113 parts.push(llm_toolkit::prompt::PromptPart::Text(text));
1114 }
1115
1116 Ok(parts)
1117 }
1118 });
1119 } else {
1120 let mut text_field_parts = Vec::new();
1122 let mut image_field_parts = Vec::new();
1123
1124 for field in fields.iter() {
1125 let field_name = field.ident.as_ref().unwrap();
1126 let field_name_str = field_name.to_string();
1127
1128 let config = target.field_configs.get(&field_name_str);
1130
1131 if let Some(cfg) = config
1133 && cfg.skip
1134 {
1135 continue;
1136 }
1137
1138 let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1142 let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1143 .iter()
1144 .any(|(name, _)| name != "*");
1145
1146 if has_any_target_specific_config && !is_explicitly_for_this_target {
1147 continue;
1148 }
1149
1150 if let Some(cfg) = config {
1151 if cfg.image {
1152 image_field_parts.push(quote! {
1153 parts.extend(self.#field_name.to_prompt_parts());
1154 });
1155 } else {
1156 let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1157
1158 let value_expr = if let Some(format_with) = &cfg.format_with {
1159 match syn::parse_str::<syn::Path>(format_with) {
1161 Ok(func_path) => quote! { #func_path(&self.#field_name) },
1162 Err(_) => {
1163 let error_msg = format!(
1165 "Invalid function path in format_with: '{}'",
1166 format_with
1167 );
1168 quote! {
1169 compile_error!(#error_msg);
1170 String::new()
1171 }
1172 }
1173 }
1174 } else {
1175 quote! { self.#field_name.to_prompt() }
1176 };
1177
1178 text_field_parts.push(quote! {
1179 text_parts.push(format!("{}: {}", #key, #value_expr));
1180 });
1181 }
1182 } else {
1183 text_field_parts.push(quote! {
1185 text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1186 });
1187 }
1188 }
1189
1190 match_arms.push(quote! {
1191 #target_name => {
1192 let mut parts = Vec::new();
1193
1194 #(#image_field_parts)*
1195
1196 let mut text_parts = Vec::new();
1197 #(#text_field_parts)*
1198
1199 if !text_parts.is_empty() {
1200 parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1201 }
1202
1203 Ok(parts)
1204 }
1205 });
1206 }
1207 }
1208
1209 let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1211
1212 match_arms.push(quote! {
1214 _ => {
1215 let available = vec![#(#target_names.to_string()),*];
1216 Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
1217 target: target.to_string(),
1218 available,
1219 })
1220 }
1221 });
1222
1223 let struct_name = &input.ident;
1224 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1225
1226 let expanded = quote! {
1227 impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1228 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
1229 match target {
1230 #(#match_arms)*
1231 }
1232 }
1233 }
1234 };
1235
1236 TokenStream::from(expanded)
1237}
1238
1239struct TypeList {
1241 types: Punctuated<syn::Type, Token![,]>,
1242}
1243
1244impl Parse for TypeList {
1245 fn parse(input: ParseStream) -> syn::Result<Self> {
1246 Ok(TypeList {
1247 types: Punctuated::parse_terminated(input)?,
1248 })
1249 }
1250}
1251
1252#[proc_macro]
1276pub fn examples_section(input: TokenStream) -> TokenStream {
1277 let input = parse_macro_input!(input as TypeList);
1278
1279 let mut type_sections = Vec::new();
1281
1282 for ty in input.types.iter() {
1283 let type_name_str = quote!(#ty).to_string();
1285
1286 type_sections.push(quote! {
1288 {
1289 let type_name = #type_name_str;
1290 let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1291 format!("---\n#### `{}`\n{}", type_name, json_example)
1292 }
1293 });
1294 }
1295
1296 let expanded = quote! {
1298 {
1299 let mut sections = Vec::new();
1300 sections.push("---".to_string());
1301 sections.push("### Examples".to_string());
1302 sections.push("".to_string());
1303 sections.push("Here are examples of the data structures you should use.".to_string());
1304 sections.push("".to_string());
1305
1306 #(sections.push(#type_sections);)*
1307
1308 sections.push("---".to_string());
1309
1310 sections.join("\n")
1311 }
1312 };
1313
1314 TokenStream::from(expanded)
1315}
1316
1317fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1319 for attr in attrs {
1320 if attr.path().is_ident("prompt_for")
1321 && let Ok(meta_list) = attr.meta.require_list()
1322 && let Ok(metas) =
1323 meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1324 {
1325 let mut target_type = None;
1326 let mut template = None;
1327
1328 for meta in metas {
1329 match meta {
1330 Meta::NameValue(nv) if nv.path.is_ident("target") => {
1331 if let syn::Expr::Lit(syn::ExprLit {
1332 lit: syn::Lit::Str(lit_str),
1333 ..
1334 }) = nv.value
1335 {
1336 target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1338 }
1339 }
1340 Meta::NameValue(nv) if nv.path.is_ident("template") => {
1341 if let syn::Expr::Lit(syn::ExprLit {
1342 lit: syn::Lit::Str(lit_str),
1343 ..
1344 }) = nv.value
1345 {
1346 template = Some(lit_str.value());
1347 }
1348 }
1349 _ => {}
1350 }
1351 }
1352
1353 if let (Some(target), Some(tmpl)) = (target_type, template) {
1354 return (target, tmpl);
1355 }
1356 }
1357 }
1358
1359 panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1360}
1361
1362#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
1364pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
1365 let input = parse_macro_input!(input as DeriveInput);
1366
1367 let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
1369
1370 let struct_name = &input.ident;
1371 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1372
1373 let placeholders = parse_template_placeholders(&template);
1375
1376 let mut converted_template = template.clone();
1378 let mut context_fields = Vec::new();
1379
1380 let fields = match &input.data {
1382 Data::Struct(data_struct) => match &data_struct.fields {
1383 syn::Fields::Named(fields) => &fields.named,
1384 _ => panic!("ToPromptFor is only supported for structs with named fields"),
1385 },
1386 _ => panic!("ToPromptFor is only supported for structs"),
1387 };
1388
1389 let has_mode_support = input.attrs.iter().any(|attr| {
1391 if attr.path().is_ident("prompt")
1392 && let Ok(metas) =
1393 attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1394 {
1395 for meta in metas {
1396 if let Meta::NameValue(nv) = meta
1397 && nv.path.is_ident("mode")
1398 {
1399 return true;
1400 }
1401 }
1402 }
1403 false
1404 });
1405
1406 for (placeholder_name, mode_opt) in &placeholders {
1408 if placeholder_name == "self" {
1409 if let Some(specific_mode) = mode_opt {
1410 let unique_key = format!("self__{}", specific_mode);
1412
1413 let pattern = format!("{{self:{}}}", specific_mode);
1415 let replacement = format!("{{{{{}}}}}", unique_key);
1416 converted_template = converted_template.replace(&pattern, &replacement);
1417
1418 context_fields.push(quote! {
1420 context.insert(
1421 #unique_key.to_string(),
1422 minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
1423 );
1424 });
1425 } else {
1426 let pattern = "{self}";
1428 let replacement = "{{self}}";
1429 converted_template = converted_template.replace(pattern, replacement);
1430
1431 if has_mode_support {
1432 context_fields.push(quote! {
1434 context.insert(
1435 "self".to_string(),
1436 minijinja::Value::from(self.to_prompt_with_mode(mode))
1437 );
1438 });
1439 } else {
1440 context_fields.push(quote! {
1442 context.insert(
1443 "self".to_string(),
1444 minijinja::Value::from(self.to_prompt())
1445 );
1446 });
1447 }
1448 }
1449 } else {
1450 let field_exists = fields.iter().any(|f| {
1453 f.ident
1454 .as_ref()
1455 .is_some_and(|ident| ident == placeholder_name)
1456 });
1457
1458 if field_exists {
1459 let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
1460
1461 let pattern = format!("{{{}}}", placeholder_name);
1463 let replacement = format!("{{{{{}}}}}", placeholder_name);
1464 converted_template = converted_template.replace(&pattern, &replacement);
1465
1466 context_fields.push(quote! {
1468 context.insert(
1469 #placeholder_name.to_string(),
1470 minijinja::Value::from_serialize(&self.#field_ident)
1471 );
1472 });
1473 }
1474 }
1476 }
1477
1478 let expanded = quote! {
1479 impl #impl_generics llm_toolkit::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
1480 where
1481 #target_type: serde::Serialize,
1482 {
1483 fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
1484 let mut env = minijinja::Environment::new();
1486 env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
1487 panic!("Failed to parse template: {}", e)
1488 });
1489
1490 let tmpl = env.get_template("prompt").unwrap();
1491
1492 let mut context = std::collections::HashMap::new();
1494 context.insert(
1496 "self".to_string(),
1497 minijinja::Value::from_serialize(self)
1498 );
1499 context.insert(
1501 "target".to_string(),
1502 minijinja::Value::from_serialize(target)
1503 );
1504 #(#context_fields)*
1505
1506 tmpl.render(context).unwrap_or_else(|e| {
1508 format!("Failed to render prompt: {}", e)
1509 })
1510 }
1511 }
1512 };
1513
1514 TokenStream::from(expanded)
1515}