llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use regex::Regex;
4use syn::{
5    Data, DeriveInput, Meta, Token,
6    parse::{Parse, ParseStream},
7    parse_macro_input,
8    punctuated::Punctuated,
9};
10
11/// Parse template placeholders using regex to find :mode patterns
12/// Returns a list of (field_name, optional_mode)
13fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
14    let mut placeholders = Vec::new();
15    let mut seen_fields = std::collections::HashSet::new();
16
17    // First, find all {{ field:mode }} patterns
18    let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
19    for cap in mode_pattern.captures_iter(template) {
20        let field_name = cap[1].to_string();
21        let mode = cap[2].to_string();
22        placeholders.push((field_name.clone(), Some(mode)));
23        seen_fields.insert(field_name);
24    }
25
26    // Then, find all standard {{ field }} patterns (without mode)
27    let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
28    for cap in standard_pattern.captures_iter(template) {
29        let field_name = cap[1].to_string();
30        // Check if this field was already captured with a mode
31        if !seen_fields.contains(&field_name) {
32            placeholders.push((field_name, None));
33        }
34    }
35
36    placeholders
37}
38
39/// Extract doc comments from attributes
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
41    attrs
42        .iter()
43        .filter_map(|attr| {
44            if attr.path().is_ident("doc")
45                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
46                && let syn::Expr::Lit(syn::ExprLit {
47                    lit: syn::Lit::Str(lit_str),
48                    ..
49                }) = &meta_name_value.value
50            {
51                return Some(lit_str.value());
52            }
53            None
54        })
55        .map(|s| s.trim().to_string())
56        .collect::<Vec<_>>()
57        .join(" ")
58}
59
60/// Generate example JSON representation for a struct
61fn generate_example_only_parts(
62    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
63    has_default: bool,
64) -> proc_macro2::TokenStream {
65    let mut field_values = Vec::new();
66
67    for field in fields.iter() {
68        let field_name = field.ident.as_ref().unwrap();
69        let field_name_str = field_name.to_string();
70        let attrs = parse_field_prompt_attrs(&field.attrs);
71
72        // Skip if marked to skip
73        if attrs.skip {
74            continue;
75        }
76
77        // Check if field has example attribute
78        if let Some(example) = attrs.example {
79            // Use the provided example value
80            field_values.push(quote! {
81                json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
82            });
83        } else if has_default {
84            // Use Default value if available
85            field_values.push(quote! {
86                let default_value = serde_json::to_value(&default_instance.#field_name)
87                    .unwrap_or(serde_json::Value::Null);
88                json_obj.insert(#field_name_str.to_string(), default_value);
89            });
90        } else {
91            // Use self's actual value
92            field_values.push(quote! {
93                let value = serde_json::to_value(&self.#field_name)
94                    .unwrap_or(serde_json::Value::Null);
95                json_obj.insert(#field_name_str.to_string(), value);
96            });
97        }
98    }
99
100    if has_default {
101        quote! {
102            {
103                let default_instance = Self::default();
104                let mut json_obj = serde_json::Map::new();
105                #(#field_values)*
106                let json_value = serde_json::Value::Object(json_obj);
107                let json_str = serde_json::to_string_pretty(&json_value)
108                    .unwrap_or_else(|_| "{}".to_string());
109                vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
110            }
111        }
112    } else {
113        quote! {
114            {
115                let mut json_obj = serde_json::Map::new();
116                #(#field_values)*
117                let json_value = serde_json::Value::Object(json_obj);
118                let json_str = serde_json::to_string_pretty(&json_value)
119                    .unwrap_or_else(|_| "{}".to_string());
120                vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
121            }
122        }
123    }
124}
125
126/// Generate schema-only representation for a struct
127fn generate_schema_only_parts(
128    struct_name: &str,
129    struct_docs: &str,
130    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
131) -> proc_macro2::TokenStream {
132    let mut schema_lines = vec![];
133
134    // Add header
135    if !struct_docs.is_empty() {
136        schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
137    } else {
138        schema_lines.push(format!("### Schema for `{}`", struct_name));
139    }
140
141    schema_lines.push("{".to_string());
142
143    // Process fields
144    for (i, field) in fields.iter().enumerate() {
145        let field_name = field.ident.as_ref().unwrap();
146        let attrs = parse_field_prompt_attrs(&field.attrs);
147
148        // Skip if marked to skip
149        if attrs.skip {
150            continue;
151        }
152
153        // Get field documentation
154        let field_docs = extract_doc_comments(&field.attrs);
155
156        // Determine the type representation
157        let type_str = format_type_for_schema(&field.ty);
158
159        // Build field line
160        let mut field_line = format!("  \"{}\": \"{}\"", field_name, type_str);
161
162        // Add comment if there's documentation
163        if !field_docs.is_empty() {
164            field_line.push_str(&format!(", // {}", field_docs));
165        }
166
167        // Add comma if not last field (accounting for skipped fields)
168        let remaining_fields = fields
169            .iter()
170            .skip(i + 1)
171            .filter(|f| {
172                let attrs = parse_field_prompt_attrs(&f.attrs);
173                !attrs.skip
174            })
175            .count();
176
177        if remaining_fields > 0 {
178            field_line.push(',');
179        }
180
181        schema_lines.push(field_line);
182    }
183
184    schema_lines.push("}".to_string());
185
186    let schema_str = schema_lines.join("\n");
187
188    quote! {
189        vec![llm_toolkit::prompt::PromptPart::Text(#schema_str.to_string())]
190    }
191}
192
193/// Format a type for schema representation
194fn format_type_for_schema(ty: &syn::Type) -> String {
195    // Simple type formatting - can be enhanced
196    match ty {
197        syn::Type::Path(type_path) => {
198            let path = &type_path.path;
199            if let Some(last_segment) = path.segments.last() {
200                let type_name = last_segment.ident.to_string();
201
202                // Handle Option<T>
203                if type_name == "Option"
204                    && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
205                    && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
206                {
207                    return format!("{} | null", format_type_for_schema(inner_type));
208                }
209
210                // Map common types
211                match type_name.as_str() {
212                    "String" | "str" => "string".to_string(),
213                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
214                    | "u64" | "u128" | "usize" => "number".to_string(),
215                    "f32" | "f64" => "number".to_string(),
216                    "bool" => "boolean".to_string(),
217                    "Vec" => {
218                        if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
219                            && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
220                        {
221                            return format!("{}[]", format_type_for_schema(inner_type));
222                        }
223                        "array".to_string()
224                    }
225                    _ => type_name.to_lowercase(),
226                }
227            } else {
228                "unknown".to_string()
229            }
230        }
231        _ => "unknown".to_string(),
232    }
233}
234
/// Result of parsing prompt attribute
///
/// Produced by `parse_prompt_attribute` for a single enum variant's attributes.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the generated prompt.
    Skip,
    /// `#[prompt("...")]`: the given string is used as the variant's description.
    Description(String),
    /// No recognized `#[prompt]` attribute; the enum prompt falls back to the
    /// variant's doc comment, or to the bare variant name.
    None,
}
241
242/// Parse #[prompt(...)] attribute on enum variant
243fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
244    for attr in attrs {
245        if attr.path().is_ident("prompt") {
246            // Check for #[prompt(skip)]
247            if let Ok(meta_list) = attr.meta.require_list() {
248                let tokens = &meta_list.tokens;
249                let tokens_str = tokens.to_string();
250                if tokens_str == "skip" {
251                    return PromptAttribute::Skip;
252                }
253            }
254
255            // Check for #[prompt("description")]
256            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
257                return PromptAttribute::Description(lit_str.value());
258            }
259        }
260    }
261    PromptAttribute::None
262}
263
/// Parsed field-level prompt attributes
///
/// Filled in by `parse_field_prompt_attrs` from a field's `#[prompt(...)]` attributes;
/// `Default` gives "no attributes present".
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by `#[prompt(skip)]`; skipped fields are omitted from generated schemas and examples.
    skip: bool,
    /// Value of `#[prompt(rename = "...")]`, if given as a string literal.
    /// NOTE(review): not read in the visible part of this file — consumer presumably later; confirm.
    rename: Option<String>,
    /// Value of `#[prompt(format_with = "...")]`, if given as a string literal.
    /// NOTE(review): not read in the visible part of this file — consumer presumably later; confirm.
    format_with: Option<String>,
    /// Set by `#[prompt(image)]`.
    /// NOTE(review): not read in the visible part of this file — consumer presumably later; confirm.
    image: bool,
    /// Value of `#[prompt(example = "...")]`; used verbatim (as a JSON string) in example output.
    example: Option<String>,
}
273
274/// Parse #[prompt(...)] attributes for struct fields
275fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
276    let mut result = FieldPromptAttrs::default();
277
278    for attr in attrs {
279        if attr.path().is_ident("prompt") {
280            // Try to parse as meta list #[prompt(key = value, ...)]
281            if let Ok(meta_list) = attr.meta.require_list() {
282                // Parse the tokens inside the parentheses
283                if let Ok(metas) =
284                    meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
285                {
286                    for meta in metas {
287                        match meta {
288                            Meta::Path(path) if path.is_ident("skip") => {
289                                result.skip = true;
290                            }
291                            Meta::NameValue(nv) if nv.path.is_ident("rename") => {
292                                if let syn::Expr::Lit(syn::ExprLit {
293                                    lit: syn::Lit::Str(lit_str),
294                                    ..
295                                }) = nv.value
296                                {
297                                    result.rename = Some(lit_str.value());
298                                }
299                            }
300                            Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
301                                if let syn::Expr::Lit(syn::ExprLit {
302                                    lit: syn::Lit::Str(lit_str),
303                                    ..
304                                }) = nv.value
305                                {
306                                    result.format_with = Some(lit_str.value());
307                                }
308                            }
309                            Meta::Path(path) if path.is_ident("image") => {
310                                result.image = true;
311                            }
312                            Meta::NameValue(nv) if nv.path.is_ident("example") => {
313                                if let syn::Expr::Lit(syn::ExprLit {
314                                    lit: syn::Lit::Str(lit_str),
315                                    ..
316                                }) = nv.value
317                                {
318                                    result.example = Some(lit_str.value());
319                                }
320                            }
321                            _ => {}
322                        }
323                    }
324                } else if meta_list.tokens.to_string() == "skip" {
325                    // Handle simple #[prompt(skip)] case
326                    result.skip = true;
327                } else if meta_list.tokens.to_string() == "image" {
328                    // Handle simple #[prompt(image)] case
329                    result.image = true;
330                }
331            }
332        }
333    }
334
335    result
336}
337
338/// Derives the `ToPrompt` trait for a struct or enum.
339///
340/// This macro provides two main functionalities depending on the type.
341///
342/// ## For Structs
343///
344/// It can generate a prompt based on a template string or by creating a key-value representation of the struct's fields.
345///
346/// ### Template-based Prompt
347///
348/// Use the `#[prompt(template = "...")]` attribute to provide a `minijinja` template. The struct fields will be available as variables in the template. The struct must also derive `serde::Serialize`.
349///
350/// ```rust,ignore
351/// #[derive(ToPrompt, Serialize)]
352/// #[prompt(template = "User {{ name }} is a {{ role }}.")]
353/// struct UserProfile {
354///     name: &'static str,
355///     role: &'static str,
356/// }
357/// ```
358///
359/// ### Tip: Handling Special Characters in Templates
360///
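/// When using raw string literals (e.g., `r#"..."#`) for your templates, be aware that an `r#"..."#` literal ends at the first `"#` sequence. Template content containing a double quote immediately followed by `#` (such as `"#FFFFFF"`) therefore cuts the string short and fails to compile. To avoid this, delimit the raw string with more `#` symbols than ever follow a `"` in the content (e.g. `r##"..."##`).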
362///
363/// **Problematic Example:**
364/// ```rust,ignore
/// // This fails to compile: the `"#` before `FFFFFF` ends the raw string early
366/// #[prompt(template = r#"{"color": "#FFFFFF"}"#)]
367/// struct Color { /* ... */ }
368/// ```
369///
370/// **Solution:**
371/// ```rust,ignore
372/// // Use r##"..."## to avoid ambiguity with the inner '#'
373/// #[prompt(template = r##"{"color": "#FFFFFF"}"##)]
374/// struct Color { /* ... */ }
375/// ```
376///
377/// ## For Enums
378///
379/// For enums, the macro generates a descriptive prompt based on doc comments and attributes, outlining the available variants. See the documentation on the `ToPrompt` trait for more details.
380#[proc_macro_derive(ToPrompt, attributes(prompt))]
381pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
382    let input = parse_macro_input!(input as DeriveInput);
383
384    // Check if this is a struct or enum
385    match &input.data {
386        Data::Enum(data_enum) => {
387            // For enums, generate prompt from doc comments
388            let enum_name = &input.ident;
389            let enum_docs = extract_doc_comments(&input.attrs);
390
391            let mut prompt_lines = Vec::new();
392
393            // Add enum description
394            if !enum_docs.is_empty() {
395                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
396            } else {
397                prompt_lines.push(format!("{}:", enum_name));
398            }
399            prompt_lines.push(String::new()); // Empty line
400            prompt_lines.push("Possible values:".to_string());
401
402            // Add each variant with its documentation based on priority
403            for variant in &data_enum.variants {
404                let variant_name = &variant.ident;
405
406                // Apply fallback logic with priority
407                match parse_prompt_attribute(&variant.attrs) {
408                    PromptAttribute::Skip => {
409                        // Skip this variant completely
410                        continue;
411                    }
412                    PromptAttribute::Description(desc) => {
413                        // Use custom description from #[prompt("...")]
414                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
415                    }
416                    PromptAttribute::None => {
417                        // Fall back to doc comment or just variant name
418                        let variant_docs = extract_doc_comments(&variant.attrs);
419                        if !variant_docs.is_empty() {
420                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
421                        } else {
422                            prompt_lines.push(format!("- {}", variant_name));
423                        }
424                    }
425                }
426            }
427
428            let prompt_string = prompt_lines.join("\n");
429            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
430
431            let expanded = quote! {
432                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
433                    fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
434                        vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
435                    }
436
437                    fn to_prompt(&self) -> String {
438                        #prompt_string.to_string()
439                    }
440                }
441            };
442
443            TokenStream::from(expanded)
444        }
445        Data::Struct(data_struct) => {
446            // Parse struct-level prompt attributes for template, template_file, mode, and validate
447            let mut template_attr = None;
448            let mut template_file_attr = None;
449            let mut mode_attr = None;
450            let mut validate_attr = false;
451
452            for attr in &input.attrs {
453                if attr.path().is_ident("prompt") {
454                    // Try to parse the attribute arguments
455                    if let Ok(metas) =
456                        attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
457                    {
458                        for meta in metas {
459                            match meta {
460                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
461                                    if let syn::Expr::Lit(expr_lit) = nv.value
462                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
463                                    {
464                                        template_attr = Some(lit_str.value());
465                                    }
466                                }
467                                Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
468                                    if let syn::Expr::Lit(expr_lit) = nv.value
469                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
470                                    {
471                                        template_file_attr = Some(lit_str.value());
472                                    }
473                                }
474                                Meta::NameValue(nv) if nv.path.is_ident("mode") => {
475                                    if let syn::Expr::Lit(expr_lit) = nv.value
476                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
477                                    {
478                                        mode_attr = Some(lit_str.value());
479                                    }
480                                }
481                                Meta::NameValue(nv) if nv.path.is_ident("validate") => {
482                                    if let syn::Expr::Lit(expr_lit) = nv.value
483                                        && let syn::Lit::Bool(lit_bool) = expr_lit.lit
484                                    {
485                                        validate_attr = lit_bool.value();
486                                    }
487                                }
488                                _ => {}
489                            }
490                        }
491                    }
492                }
493            }
494
495            // Check for mutual exclusivity between template and template_file
496            if template_attr.is_some() && template_file_attr.is_some() {
497                return syn::Error::new(
498                    input.ident.span(),
499                    "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
500                ).to_compile_error().into();
501            }
502
503            // Load template from file if template_file is specified
504            let template_str = if let Some(file_path) = template_file_attr {
505                // Try multiple strategies to find the template file
506                // This is necessary to support both normal compilation and trybuild tests
507
508                let mut full_path = None;
509
510                // Strategy 1: Try relative to CARGO_MANIFEST_DIR (normal compilation)
511                if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
512                    // Check if this is a trybuild temporary directory
513                    let is_trybuild = manifest_dir.contains("target/tests/trybuild");
514
515                    if !is_trybuild {
516                        // Normal compilation - use CARGO_MANIFEST_DIR directly
517                        let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
518                        if candidate.exists() {
519                            full_path = Some(candidate);
520                        }
521                    } else {
522                        // For trybuild, we need to find the original source directory
523                        // The manifest_dir looks like: .../target/tests/trybuild/llm-toolkit-macros
524                        // We need to get back to the original llm-toolkit-macros source directory
525
526                        // Extract the workspace root from the path
527                        if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
528                            let workspace_root = &manifest_dir[..target_pos];
529                            // Now construct the path to the original llm-toolkit-macros source
530                            let original_macros_dir = std::path::Path::new(workspace_root)
531                                .join("crates")
532                                .join("llm-toolkit-macros");
533
534                            let candidate = original_macros_dir.join(&file_path);
535                            if candidate.exists() {
536                                full_path = Some(candidate);
537                            }
538                        }
539                    }
540                }
541
542                // Strategy 2: Try as an absolute path or relative to current directory
543                if full_path.is_none() {
544                    let candidate = std::path::Path::new(&file_path).to_path_buf();
545                    if candidate.exists() {
546                        full_path = Some(candidate);
547                    }
548                }
549
550                // Strategy 3: For trybuild tests - try to find the file by looking in parent directories
551                // This handles the case where trybuild creates a temporary project
552                if full_path.is_none()
553                    && let Ok(current_dir) = std::env::current_dir()
554                {
555                    let mut search_dir = current_dir.as_path();
556                    // Search up to 10 levels up
557                    for _ in 0..10 {
558                        // Try from the llm-toolkit-macros directory
559                        let macros_dir = search_dir.join("crates/llm-toolkit-macros");
560                        if macros_dir.exists() {
561                            let candidate = macros_dir.join(&file_path);
562                            if candidate.exists() {
563                                full_path = Some(candidate);
564                                break;
565                            }
566                        }
567                        // Try directly
568                        let candidate = search_dir.join(&file_path);
569                        if candidate.exists() {
570                            full_path = Some(candidate);
571                            break;
572                        }
573                        if let Some(parent) = search_dir.parent() {
574                            search_dir = parent;
575                        } else {
576                            break;
577                        }
578                    }
579                }
580
581                // If we still haven't found the file, use the original path for a better error message
582                let final_path =
583                    full_path.unwrap_or_else(|| std::path::Path::new(&file_path).to_path_buf());
584
585                // Read the file at compile time
586                match std::fs::read_to_string(&final_path) {
587                    Ok(content) => Some(content),
588                    Err(e) => {
589                        return syn::Error::new(
590                            input.ident.span(),
591                            format!(
592                                "Failed to read template file '{}': {}",
593                                final_path.display(),
594                                e
595                            ),
596                        )
597                        .to_compile_error()
598                        .into();
599                    }
600                }
601            } else {
602                template_attr
603            };
604
605            // Perform validation if requested
606            if validate_attr && let Some(template) = &template_str {
607                // Validate Jinja syntax
608                let mut env = minijinja::Environment::new();
609                if let Err(e) = env.add_template("validation", template) {
610                    // Generate a compile warning using deprecated const hack
611                    let warning_msg =
612                        format!("Template validation warning: Invalid Jinja syntax - {}", e);
613                    let warning_ident = syn::Ident::new(
614                        "TEMPLATE_VALIDATION_WARNING",
615                        proc_macro2::Span::call_site(),
616                    );
617                    let _warning_tokens = quote! {
618                        #[deprecated(note = #warning_msg)]
619                        const #warning_ident: () = ();
620                        let _ = #warning_ident;
621                    };
622                    // We'll inject this warning into the generated code
623                    eprintln!("cargo:warning={}", warning_msg);
624                }
625
626                // Extract variables from template and check against struct fields
627                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
628                    &fields.named
629                } else {
630                    panic!("Template validation is only supported for structs with named fields.");
631                };
632
633                let field_names: std::collections::HashSet<String> = fields
634                    .iter()
635                    .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
636                    .collect();
637
638                // Parse template placeholders
639                let placeholders = parse_template_placeholders_with_mode(template);
640
641                for (placeholder_name, _mode) in &placeholders {
642                    if placeholder_name != "self" && !field_names.contains(placeholder_name) {
643                        let warning_msg = format!(
644                            "Template validation warning: Variable '{}' used in template but not found in struct fields",
645                            placeholder_name
646                        );
647                        eprintln!("cargo:warning={}", warning_msg);
648                    }
649                }
650            }
651
652            let name = input.ident;
653            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
654
655            // Extract struct name and doc comment for use in schema generation
656            let struct_docs = extract_doc_comments(&input.attrs);
657
658            // Check if this is a mode-based struct (mode attribute present)
659            let is_mode_based =
660                mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
661
662            let expanded = if is_mode_based || mode_attr.is_some() {
663                // Mode-based generation: support schema_only, example_only, full
664                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
665                    &fields.named
666                } else {
667                    panic!(
668                        "Mode-based prompt generation is only supported for structs with named fields."
669                    );
670                };
671
672                let struct_name_str = name.to_string();
673
674                // Check if struct derives Default
675                let has_default = input.attrs.iter().any(|attr| {
676                    if attr.path().is_ident("derive") {
677                        if let Ok(meta_list) = attr.meta.require_list() {
678                            let tokens_str = meta_list.tokens.to_string();
679                            tokens_str.contains("Default")
680                        } else {
681                            false
682                        }
683                    } else {
684                        false
685                    }
686                });
687
688                // Generate schema-only parts
689                let schema_parts =
690                    generate_schema_only_parts(&struct_name_str, &struct_docs, fields);
691
692                // Generate example parts
693                let example_parts = generate_example_only_parts(fields, has_default);
694
695                quote! {
696                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
697                        fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<llm_toolkit::prompt::PromptPart> {
698                            match mode {
699                                "schema_only" => #schema_parts,
700                                "example_only" => #example_parts,
701                                "full" | _ => {
702                                    // Combine schema and example
703                                    let mut parts = Vec::new();
704
705                                    // Add schema
706                                    let schema_parts = #schema_parts;
707                                    parts.extend(schema_parts);
708
709                                    // Add separator and example header
710                                    parts.push(llm_toolkit::prompt::PromptPart::Text("\n### Example".to_string()));
711                                    parts.push(llm_toolkit::prompt::PromptPart::Text(
712                                        format!("Here is an example of a valid `{}` object:", #struct_name_str)
713                                    ));
714
715                                    // Add example
716                                    let example_parts = #example_parts;
717                                    parts.extend(example_parts);
718
719                                    parts
720                                }
721                            }
722                        }
723
724                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
725                            self.to_prompt_parts_with_mode("full")
726                        }
727
728                        fn to_prompt(&self) -> String {
729                            self.to_prompt_parts()
730                                .into_iter()
731                                .filter_map(|part| match part {
732                                    llm_toolkit::prompt::PromptPart::Text(text) => Some(text),
733                                    _ => None,
734                                })
735                                .collect::<Vec<_>>()
736                                .join("\n")
737                        }
738                    }
739                }
740            } else if let Some(template) = template_str {
741                // Use template-based approach if template is provided
742                // Collect image fields separately for to_prompt_parts()
743                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
744                    &fields.named
745                } else {
746                    panic!(
747                        "Template prompt generation is only supported for structs with named fields."
748                    );
749                };
750
751                // Parse template to detect mode syntax
752                let placeholders = parse_template_placeholders_with_mode(&template);
753                // Only use custom mode processing if template actually contains :mode syntax
754                let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
755                    mode.is_some()
756                        && fields
757                            .iter()
758                            .any(|f| f.ident.as_ref().unwrap() == field_name)
759                });
760
761                let mut image_field_parts = Vec::new();
762                for f in fields.iter() {
763                    let field_name = f.ident.as_ref().unwrap();
764                    let attrs = parse_field_prompt_attrs(&f.attrs);
765
766                    if attrs.image {
767                        // This field is marked as an image
768                        image_field_parts.push(quote! {
769                            parts.extend(self.#field_name.to_prompt_parts());
770                        });
771                    }
772                }
773
774                // Generate appropriate code based on whether mode syntax is used
775                if has_mode_syntax {
776                    // Build custom context for fields with mode specifications
777                    let mut context_fields = Vec::new();
778                    let mut modified_template = template.clone();
779
780                    // Process each placeholder with mode
781                    for (field_name, mode_opt) in &placeholders {
782                        if let Some(mode) = mode_opt {
783                            // Create a unique key for this field:mode combination
784                            let unique_key = format!("{}__{}", field_name, mode);
785
786                            // Replace {{ field:mode }} with {{ field__mode }} in template
787                            let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
788                            let replacement = format!("{{{{ {} }}}}", unique_key);
789                            modified_template = modified_template.replace(&pattern, &replacement);
790
791                            // Find the corresponding field
792                            let field_ident =
793                                syn::Ident::new(field_name, proc_macro2::Span::call_site());
794
795                            // Add to context with mode specification
796                            context_fields.push(quote! {
797                                context.insert(
798                                    #unique_key.to_string(),
799                                    minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
800                                );
801                            });
802                        }
803                    }
804
805                    // Add individual fields via direct access (for non-mode fields)
806                    for field in fields.iter() {
807                        let field_name = field.ident.as_ref().unwrap();
808                        let field_name_str = field_name.to_string();
809
810                        // Skip if this field already has a mode-specific entry
811                        let has_mode_entry = placeholders
812                            .iter()
813                            .any(|(name, mode)| name == &field_name_str && mode.is_some());
814
815                        if !has_mode_entry {
816                            // Check if field type is likely a struct that implements ToPrompt
817                            // (not a primitive type)
818                            let is_primitive = match &field.ty {
819                                syn::Type::Path(type_path) => {
820                                    if let Some(segment) = type_path.path.segments.last() {
821                                        let type_name = segment.ident.to_string();
822                                        matches!(
823                                            type_name.as_str(),
824                                            "String"
825                                                | "str"
826                                                | "i8"
827                                                | "i16"
828                                                | "i32"
829                                                | "i64"
830                                                | "i128"
831                                                | "isize"
832                                                | "u8"
833                                                | "u16"
834                                                | "u32"
835                                                | "u64"
836                                                | "u128"
837                                                | "usize"
838                                                | "f32"
839                                                | "f64"
840                                                | "bool"
841                                                | "char"
842                                        )
843                                    } else {
844                                        false
845                                    }
846                                }
847                                _ => false,
848                            };
849
850                            if is_primitive {
851                                context_fields.push(quote! {
852                                    context.insert(
853                                        #field_name_str.to_string(),
854                                        minijinja::Value::from_serialize(&self.#field_name)
855                                    );
856                                });
857                            } else {
858                                // For non-primitive types, use to_prompt()
859                                context_fields.push(quote! {
860                                    context.insert(
861                                        #field_name_str.to_string(),
862                                        minijinja::Value::from(self.#field_name.to_prompt())
863                                    );
864                                });
865                            }
866                        }
867                    }
868
869                    quote! {
870                        impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
871                            fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
872                                let mut parts = Vec::new();
873
874                                // Add image parts first
875                                #(#image_field_parts)*
876
877                                // Build custom context and render template
878                                let text = {
879                                    let mut env = minijinja::Environment::new();
880                                    env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
881                                        panic!("Failed to parse template: {}", e)
882                                    });
883
884                                    let tmpl = env.get_template("prompt").unwrap();
885
886                                    let mut context = std::collections::HashMap::new();
887                                    #(#context_fields)*
888
889                                    tmpl.render(context).unwrap_or_else(|e| {
890                                        format!("Failed to render prompt: {}", e)
891                                    })
892                                };
893
894                                if !text.is_empty() {
895                                    parts.push(llm_toolkit::prompt::PromptPart::Text(text));
896                                }
897
898                                parts
899                            }
900
901                            fn to_prompt(&self) -> String {
902                                // Same logic for to_prompt
903                                let mut env = minijinja::Environment::new();
904                                env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
905                                    panic!("Failed to parse template: {}", e)
906                                });
907
908                                let tmpl = env.get_template("prompt").unwrap();
909
910                                let mut context = std::collections::HashMap::new();
911                                #(#context_fields)*
912
913                                tmpl.render(context).unwrap_or_else(|e| {
914                                    format!("Failed to render prompt: {}", e)
915                                })
916                            }
917                        }
918                    }
919                } else {
920                    // No mode syntax, use direct template rendering with render_prompt
921                    quote! {
922                        impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
923                            fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
924                                let mut parts = Vec::new();
925
926                                // Add image parts first
927                                #(#image_field_parts)*
928
929                                // Add the rendered template as text
930                                let text = llm_toolkit::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
931                                    format!("Failed to render prompt: {}", e)
932                                });
933                                if !text.is_empty() {
934                                    parts.push(llm_toolkit::prompt::PromptPart::Text(text));
935                                }
936
937                                parts
938                            }
939
940                            fn to_prompt(&self) -> String {
941                                llm_toolkit::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
942                                    format!("Failed to render prompt: {}", e)
943                                })
944                            }
945                        }
946                    }
947                }
948            } else {
949                // Use default key-value format if no template is provided
950                // Now also generate to_prompt_parts() for multimodal support
951                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
952                    &fields.named
953                } else {
954                    panic!(
955                        "Default prompt generation is only supported for structs with named fields."
956                    );
957                };
958
959                // Separate image fields from text fields
960                let mut text_field_parts = Vec::new();
961                let mut image_field_parts = Vec::new();
962
963                for f in fields.iter() {
964                    let field_name = f.ident.as_ref().unwrap();
965                    let attrs = parse_field_prompt_attrs(&f.attrs);
966
967                    // Skip if #[prompt(skip)] is present
968                    if attrs.skip {
969                        continue;
970                    }
971
972                    if attrs.image {
973                        // This field is marked as an image
974                        image_field_parts.push(quote! {
975                            parts.extend(self.#field_name.to_prompt_parts());
976                        });
977                    } else {
978                        // This is a regular text field
979                        // Determine the key based on priority:
980                        // 1. #[prompt(rename = "new_name")]
981                        // 2. Doc comment
982                        // 3. Field name (fallback)
983                        let key = if let Some(rename) = attrs.rename {
984                            rename
985                        } else {
986                            let doc_comment = extract_doc_comments(&f.attrs);
987                            if !doc_comment.is_empty() {
988                                doc_comment
989                            } else {
990                                field_name.to_string()
991                            }
992                        };
993
994                        // Determine the value based on format_with attribute
995                        let value_expr = if let Some(format_with) = attrs.format_with {
996                            // Parse the function path string into a syn::Path
997                            let func_path: syn::Path =
998                                syn::parse_str(&format_with).unwrap_or_else(|_| {
999                                    panic!("Invalid function path: {}", format_with)
1000                                });
1001                            quote! { #func_path(&self.#field_name) }
1002                        } else {
1003                            quote! { self.#field_name.to_prompt() }
1004                        };
1005
1006                        text_field_parts.push(quote! {
1007                            text_parts.push(format!("{}: {}", #key, #value_expr));
1008                        });
1009                    }
1010                }
1011
1012                // Generate the implementation with to_prompt_parts()
1013                quote! {
1014                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
1015                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
1016                            let mut parts = Vec::new();
1017
1018                            // Add image parts first
1019                            #(#image_field_parts)*
1020
1021                            // Collect text parts and add as a single text prompt part
1022                            let mut text_parts = Vec::new();
1023                            #(#text_field_parts)*
1024
1025                            if !text_parts.is_empty() {
1026                                parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1027                            }
1028
1029                            parts
1030                        }
1031
1032                        fn to_prompt(&self) -> String {
1033                            let mut text_parts = Vec::new();
1034                            #(#text_field_parts)*
1035                            text_parts.join("\n")
1036                        }
1037                    }
1038                }
1039            };
1040
1041            TokenStream::from(expanded)
1042        }
1043        Data::Union(_) => {
1044            panic!("`#[derive(ToPrompt)]` is not supported for unions");
1045        }
1046    }
1047}
1048
1049/// Information about a prompt target
1050#[derive(Debug, Clone)]
1051struct TargetInfo {
1052    name: String,
1053    template: Option<String>,
1054    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
1055}
1056
/// How one field is treated for one `ToPromptSet` target.
///
/// `Default` yields an all-off config: not skipped, no rename, no formatter,
/// not an image, not explicitly included.
#[derive(Clone, Debug, Default)]
struct FieldTargetConfig {
    /// Omit the field for this target (`skip`).
    skip: bool,
    /// Key to use instead of the field name in key-value output (`rename = "..."`).
    rename: Option<String>,
    /// Path of a function used to format the field value (`format_with = "..."`).
    format_with: Option<String>,
    /// Emit the field's own prompt parts instead of a text line (`image`).
    image: bool,
    /// Set when the field names this target explicitly via `name = "..."`.
    include_only: bool,
}
1066
1067/// Parse #[prompt_for(...)] attributes for ToPromptSet
1068fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1069    let mut configs = Vec::new();
1070
1071    for attr in attrs {
1072        if attr.path().is_ident("prompt_for")
1073            && let Ok(meta_list) = attr.meta.require_list()
1074        {
1075            // Try to parse as meta list
1076            if meta_list.tokens.to_string() == "skip" {
1077                // Simple #[prompt_for(skip)] applies to all targets
1078                let config = FieldTargetConfig {
1079                    skip: true,
1080                    ..Default::default()
1081                };
1082                configs.push(("*".to_string(), config));
1083            } else if let Ok(metas) =
1084                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1085            {
1086                let mut target_name = None;
1087                let mut config = FieldTargetConfig::default();
1088
1089                for meta in metas {
1090                    match meta {
1091                        Meta::NameValue(nv) if nv.path.is_ident("name") => {
1092                            if let syn::Expr::Lit(syn::ExprLit {
1093                                lit: syn::Lit::Str(lit_str),
1094                                ..
1095                            }) = nv.value
1096                            {
1097                                target_name = Some(lit_str.value());
1098                            }
1099                        }
1100                        Meta::Path(path) if path.is_ident("skip") => {
1101                            config.skip = true;
1102                        }
1103                        Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1104                            if let syn::Expr::Lit(syn::ExprLit {
1105                                lit: syn::Lit::Str(lit_str),
1106                                ..
1107                            }) = nv.value
1108                            {
1109                                config.rename = Some(lit_str.value());
1110                            }
1111                        }
1112                        Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1113                            if let syn::Expr::Lit(syn::ExprLit {
1114                                lit: syn::Lit::Str(lit_str),
1115                                ..
1116                            }) = nv.value
1117                            {
1118                                config.format_with = Some(lit_str.value());
1119                            }
1120                        }
1121                        Meta::Path(path) if path.is_ident("image") => {
1122                            config.image = true;
1123                        }
1124                        _ => {}
1125                    }
1126                }
1127
1128                if let Some(name) = target_name {
1129                    config.include_only = true;
1130                    configs.push((name, config));
1131                }
1132            }
1133        }
1134    }
1135
1136    configs
1137}
1138
1139/// Parse struct-level #[prompt_for(...)] attributes to find target templates
1140fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1141    let mut targets = Vec::new();
1142
1143    for attr in attrs {
1144        if attr.path().is_ident("prompt_for")
1145            && let Ok(meta_list) = attr.meta.require_list()
1146            && let Ok(metas) =
1147                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1148        {
1149            let mut target_name = None;
1150            let mut template = None;
1151
1152            for meta in metas {
1153                match meta {
1154                    Meta::NameValue(nv) if nv.path.is_ident("name") => {
1155                        if let syn::Expr::Lit(syn::ExprLit {
1156                            lit: syn::Lit::Str(lit_str),
1157                            ..
1158                        }) = nv.value
1159                        {
1160                            target_name = Some(lit_str.value());
1161                        }
1162                    }
1163                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
1164                        if let syn::Expr::Lit(syn::ExprLit {
1165                            lit: syn::Lit::Str(lit_str),
1166                            ..
1167                        }) = nv.value
1168                        {
1169                            template = Some(lit_str.value());
1170                        }
1171                    }
1172                    _ => {}
1173                }
1174            }
1175
1176            if let Some(name) = target_name {
1177                targets.push(TargetInfo {
1178                    name,
1179                    template,
1180                    field_configs: std::collections::HashMap::new(),
1181                });
1182            }
1183        }
1184    }
1185
1186    targets
1187}
1188
1189#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1190pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1191    let input = parse_macro_input!(input as DeriveInput);
1192
1193    // Only support structs with named fields
1194    let data_struct = match &input.data {
1195        Data::Struct(data) => data,
1196        _ => {
1197            return syn::Error::new(
1198                input.ident.span(),
1199                "`#[derive(ToPromptSet)]` is only supported for structs",
1200            )
1201            .to_compile_error()
1202            .into();
1203        }
1204    };
1205
1206    let fields = match &data_struct.fields {
1207        syn::Fields::Named(fields) => &fields.named,
1208        _ => {
1209            return syn::Error::new(
1210                input.ident.span(),
1211                "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1212            )
1213            .to_compile_error()
1214            .into();
1215        }
1216    };
1217
1218    // Parse struct-level attributes to find targets
1219    let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1220
1221    // Parse field-level attributes
1222    for field in fields.iter() {
1223        let field_name = field.ident.as_ref().unwrap().to_string();
1224        let field_configs = parse_prompt_for_attrs(&field.attrs);
1225
1226        for (target_name, config) in field_configs {
1227            if target_name == "*" {
1228                // Apply to all targets
1229                for target in &mut targets {
1230                    target
1231                        .field_configs
1232                        .entry(field_name.clone())
1233                        .or_insert_with(FieldTargetConfig::default)
1234                        .skip = config.skip;
1235                }
1236            } else {
1237                // Find or create the target
1238                let target_exists = targets.iter().any(|t| t.name == target_name);
1239                if !target_exists {
1240                    // Add implicit target if not defined at struct level
1241                    targets.push(TargetInfo {
1242                        name: target_name.clone(),
1243                        template: None,
1244                        field_configs: std::collections::HashMap::new(),
1245                    });
1246                }
1247
1248                let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1249
1250                target.field_configs.insert(field_name.clone(), config);
1251            }
1252        }
1253    }
1254
1255    // Generate match arms for each target
1256    let mut match_arms = Vec::new();
1257
1258    for target in &targets {
1259        let target_name = &target.name;
1260
1261        if let Some(template_str) = &target.template {
1262            // Template-based generation
1263            let mut image_parts = Vec::new();
1264
1265            for field in fields.iter() {
1266                let field_name = field.ident.as_ref().unwrap();
1267                let field_name_str = field_name.to_string();
1268
1269                if let Some(config) = target.field_configs.get(&field_name_str)
1270                    && config.image
1271                {
1272                    image_parts.push(quote! {
1273                        parts.extend(self.#field_name.to_prompt_parts());
1274                    });
1275                }
1276            }
1277
1278            match_arms.push(quote! {
1279                #target_name => {
1280                    let mut parts = Vec::new();
1281
1282                    #(#image_parts)*
1283
1284                    let text = llm_toolkit::prompt::render_prompt(#template_str, self)
1285                        .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
1286                            target: #target_name.to_string(),
1287                            source: e,
1288                        })?;
1289
1290                    if !text.is_empty() {
1291                        parts.push(llm_toolkit::prompt::PromptPart::Text(text));
1292                    }
1293
1294                    Ok(parts)
1295                }
1296            });
1297        } else {
1298            // Key-value based generation
1299            let mut text_field_parts = Vec::new();
1300            let mut image_field_parts = Vec::new();
1301
1302            for field in fields.iter() {
1303                let field_name = field.ident.as_ref().unwrap();
1304                let field_name_str = field_name.to_string();
1305
1306                // Check if field should be included for this target
1307                let config = target.field_configs.get(&field_name_str);
1308
1309                // Skip if explicitly marked to skip
1310                if let Some(cfg) = config
1311                    && cfg.skip
1312                {
1313                    continue;
1314                }
1315
1316                // For non-template targets, only include fields that are:
1317                // 1. Explicitly marked for this target with #[prompt_for(name = "Target")]
1318                // 2. Not marked for any specific target (default fields)
1319                let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1320                let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1321                    .iter()
1322                    .any(|(name, _)| name != "*");
1323
1324                if has_any_target_specific_config && !is_explicitly_for_this_target {
1325                    continue;
1326                }
1327
1328                if let Some(cfg) = config {
1329                    if cfg.image {
1330                        image_field_parts.push(quote! {
1331                            parts.extend(self.#field_name.to_prompt_parts());
1332                        });
1333                    } else {
1334                        let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1335
1336                        let value_expr = if let Some(format_with) = &cfg.format_with {
1337                            // Parse the function path - if it fails, generate code that will produce a compile error
1338                            match syn::parse_str::<syn::Path>(format_with) {
1339                                Ok(func_path) => quote! { #func_path(&self.#field_name) },
1340                                Err(_) => {
1341                                    // Generate a compile error by using an invalid identifier
1342                                    let error_msg = format!(
1343                                        "Invalid function path in format_with: '{}'",
1344                                        format_with
1345                                    );
1346                                    quote! {
1347                                        compile_error!(#error_msg);
1348                                        String::new()
1349                                    }
1350                                }
1351                            }
1352                        } else {
1353                            quote! { self.#field_name.to_prompt() }
1354                        };
1355
1356                        text_field_parts.push(quote! {
1357                            text_parts.push(format!("{}: {}", #key, #value_expr));
1358                        });
1359                    }
1360                } else {
1361                    // Default handling for fields without specific config
1362                    text_field_parts.push(quote! {
1363                        text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1364                    });
1365                }
1366            }
1367
1368            match_arms.push(quote! {
1369                #target_name => {
1370                    let mut parts = Vec::new();
1371
1372                    #(#image_field_parts)*
1373
1374                    let mut text_parts = Vec::new();
1375                    #(#text_field_parts)*
1376
1377                    if !text_parts.is_empty() {
1378                        parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1379                    }
1380
1381                    Ok(parts)
1382                }
1383            });
1384        }
1385    }
1386
1387    // Collect all target names for error reporting
1388    let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1389
1390    // Add default case for unknown targets
1391    match_arms.push(quote! {
1392        _ => {
1393            let available = vec![#(#target_names.to_string()),*];
1394            Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
1395                target: target.to_string(),
1396                available,
1397            })
1398        }
1399    });
1400
1401    let struct_name = &input.ident;
1402    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1403
1404    let expanded = quote! {
1405        impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1406            fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
1407                match target {
1408                    #(#match_arms)*
1409                }
1410            }
1411        }
1412    };
1413
1414    TokenStream::from(expanded)
1415}
1416
1417/// Wrapper struct for parsing a comma-separated list of types
1418struct TypeList {
1419    types: Punctuated<syn::Type, Token![,]>,
1420}
1421
1422impl Parse for TypeList {
1423    fn parse(input: ParseStream) -> syn::Result<Self> {
1424        Ok(TypeList {
1425            types: Punctuated::parse_terminated(input)?,
1426        })
1427    }
1428}
1429
1430/// Generates a formatted Markdown examples section for the provided types.
1431///
1432/// This macro accepts a comma-separated list of types and generates a single
1433/// formatted Markdown string containing examples of each type.
1434///
1435/// # Example
1436///
1437/// ```rust,ignore
1438/// let examples = examples_section!(User, Concept);
1439/// // Produces a string like:
1440/// // ---
1441/// // ### Examples
1442/// //
1443/// // Here are examples of the data structures you should use.
1444/// //
1445/// // ---
1446/// // #### `User`
1447/// // {...json...}
1448/// // ---
1449/// // #### `Concept`
1450/// // {...json...}
1451/// // ---
1452/// ```
1453#[proc_macro]
1454pub fn examples_section(input: TokenStream) -> TokenStream {
1455    let input = parse_macro_input!(input as TypeList);
1456
1457    // Generate code for each type
1458    let mut type_sections = Vec::new();
1459
1460    for ty in input.types.iter() {
1461        // Extract the type name as a string
1462        let type_name_str = quote!(#ty).to_string();
1463
1464        // Generate the section for this type
1465        type_sections.push(quote! {
1466            {
1467                let type_name = #type_name_str;
1468                let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1469                format!("---\n#### `{}`\n{}", type_name, json_example)
1470            }
1471        });
1472    }
1473
1474    // Build the complete examples string
1475    let expanded = quote! {
1476        {
1477            let mut sections = Vec::new();
1478            sections.push("---".to_string());
1479            sections.push("### Examples".to_string());
1480            sections.push("".to_string());
1481            sections.push("Here are examples of the data structures you should use.".to_string());
1482            sections.push("".to_string());
1483
1484            #(sections.push(#type_sections);)*
1485
1486            sections.push("---".to_string());
1487
1488            sections.join("\n")
1489        }
1490    };
1491
1492    TokenStream::from(expanded)
1493}
1494
1495/// Helper function to parse struct-level #[prompt_for(target = "...", template = "...")] attribute
1496fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1497    for attr in attrs {
1498        if attr.path().is_ident("prompt_for")
1499            && let Ok(meta_list) = attr.meta.require_list()
1500            && let Ok(metas) =
1501                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1502        {
1503            let mut target_type = None;
1504            let mut template = None;
1505
1506            for meta in metas {
1507                match meta {
1508                    Meta::NameValue(nv) if nv.path.is_ident("target") => {
1509                        if let syn::Expr::Lit(syn::ExprLit {
1510                            lit: syn::Lit::Str(lit_str),
1511                            ..
1512                        }) = nv.value
1513                        {
1514                            // Parse the type string into a syn::Type
1515                            target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1516                        }
1517                    }
1518                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
1519                        if let syn::Expr::Lit(syn::ExprLit {
1520                            lit: syn::Lit::Str(lit_str),
1521                            ..
1522                        }) = nv.value
1523                        {
1524                            template = Some(lit_str.value());
1525                        }
1526                    }
1527                    _ => {}
1528                }
1529            }
1530
1531            if let (Some(target), Some(tmpl)) = (target_type, template) {
1532                return (target, tmpl);
1533            }
1534        }
1535    }
1536
1537    panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1538}
1539
1540/// A procedural attribute macro that generates prompt-building functions and extractor structs for intent enums.
1541///
1542/// This macro should be applied to an enum to generate:
1543/// 1. A prompt-building function that incorporates enum documentation
1544/// 2. An extractor struct that implements `IntentExtractor`
1545///
1546/// # Requirements
1547///
1548/// The enum must have an `#[intent(...)]` attribute with:
1549/// - `prompt`: The prompt template (supports Jinja-style variables)
1550/// - `extractor_tag`: The tag to use for extraction
1551///
1552/// # Example
1553///
1554/// ```rust,ignore
1555/// #[define_intent]
1556/// #[intent(
1557///     prompt = "Analyze the intent: {{ user_input }}",
1558///     extractor_tag = "intent"
1559/// )]
1560/// enum MyIntent {
1561///     /// Create a new item
1562///     Create,
1563///     /// Update an existing item
1564///     Update,
1565///     /// Delete an item
1566///     Delete,
1567/// }
1568/// ```
1569///
1570/// This will generate:
1571/// - `pub fn build_my_intent_prompt(user_input: &str) -> String`
1572/// - `pub struct MyIntentExtractor;` with `IntentExtractor<MyIntent>` implementation
1573#[proc_macro_attribute]
1574pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1575    let input = parse_macro_input!(item as DeriveInput);
1576
1577    // Verify this is an enum
1578    let enum_data = match &input.data {
1579        Data::Enum(data) => data,
1580        _ => {
1581            return syn::Error::new(
1582                input.ident.span(),
1583                "`#[define_intent]` can only be applied to enums",
1584            )
1585            .to_compile_error()
1586            .into();
1587        }
1588    };
1589
1590    // Parse the #[intent(...)] attribute
1591    let mut prompt_template = None;
1592    let mut extractor_tag = None;
1593
1594    for attr in &input.attrs {
1595        if attr.path().is_ident("intent")
1596            && let Ok(metas) =
1597                attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1598        {
1599            for meta in metas {
1600                match meta {
1601                    Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1602                        if let syn::Expr::Lit(syn::ExprLit {
1603                            lit: syn::Lit::Str(lit_str),
1604                            ..
1605                        }) = nv.value
1606                        {
1607                            prompt_template = Some(lit_str.value());
1608                        }
1609                    }
1610                    Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1611                        if let syn::Expr::Lit(syn::ExprLit {
1612                            lit: syn::Lit::Str(lit_str),
1613                            ..
1614                        }) = nv.value
1615                        {
1616                            extractor_tag = Some(lit_str.value());
1617                        }
1618                    }
1619                    _ => {}
1620                }
1621            }
1622        }
1623    }
1624
1625    // Validate required attributes
1626    let prompt_template = match prompt_template {
1627        Some(p) => p,
1628        None => {
1629            return syn::Error::new(
1630                input.ident.span(),
1631                "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1632            )
1633            .to_compile_error()
1634            .into();
1635        }
1636    };
1637
1638    let extractor_tag = match extractor_tag {
1639        Some(t) => t,
1640        None => {
1641            return syn::Error::new(
1642                input.ident.span(),
1643                "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1644            )
1645            .to_compile_error()
1646            .into();
1647        }
1648    };
1649
1650    // Generate the intents documentation
1651    let enum_name = &input.ident;
1652    let enum_docs = extract_doc_comments(&input.attrs);
1653
1654    let mut intents_doc_lines = Vec::new();
1655
1656    // Add enum description if present
1657    if !enum_docs.is_empty() {
1658        intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1659    } else {
1660        intents_doc_lines.push(format!("{}:", enum_name));
1661    }
1662    intents_doc_lines.push(String::new()); // Empty line
1663    intents_doc_lines.push("Possible values:".to_string());
1664
1665    // Add each variant with its documentation
1666    for variant in &enum_data.variants {
1667        let variant_name = &variant.ident;
1668        let variant_docs = extract_doc_comments(&variant.attrs);
1669
1670        if !variant_docs.is_empty() {
1671            intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1672        } else {
1673            intents_doc_lines.push(format!("- {}", variant_name));
1674        }
1675    }
1676
1677    let intents_doc_str = intents_doc_lines.join("\n");
1678
1679    // Parse template variables (excluding intents_doc which we'll inject)
1680    let placeholders = parse_template_placeholders_with_mode(&prompt_template);
1681    let user_variables: Vec<String> = placeholders
1682        .iter()
1683        .filter_map(|(name, _)| {
1684            if name != "intents_doc" {
1685                Some(name.clone())
1686            } else {
1687                None
1688            }
1689        })
1690        .collect();
1691
1692    // Generate function name (snake_case)
1693    let enum_name_str = enum_name.to_string();
1694    let snake_case_name = to_snake_case(&enum_name_str);
1695    let function_name = syn::Ident::new(
1696        &format!("build_{}_prompt", snake_case_name),
1697        proc_macro2::Span::call_site(),
1698    );
1699
1700    // Generate function parameters (all &str for simplicity)
1701    let function_params: Vec<proc_macro2::TokenStream> = user_variables
1702        .iter()
1703        .map(|var| {
1704            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1705            quote! { #ident: &str }
1706        })
1707        .collect();
1708
1709    // Generate context insertions
1710    let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1711        .iter()
1712        .map(|var| {
1713            let var_str = var.clone();
1714            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1715            quote! {
1716                __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1717            }
1718        })
1719        .collect();
1720
1721    // Template is already in Jinja syntax, no conversion needed
1722    let converted_template = prompt_template.clone();
1723
1724    // Generate extractor struct name
1725    let extractor_name = syn::Ident::new(
1726        &format!("{}Extractor", enum_name),
1727        proc_macro2::Span::call_site(),
1728    );
1729
1730    // Filter out the #[intent(...)] attribute from the enum attributes
1731    let filtered_attrs: Vec<_> = input
1732        .attrs
1733        .iter()
1734        .filter(|attr| !attr.path().is_ident("intent"))
1735        .collect();
1736
1737    // Rebuild the enum with filtered attributes
1738    let vis = &input.vis;
1739    let generics = &input.generics;
1740    let variants = &enum_data.variants;
1741    let enum_output = quote! {
1742        #(#filtered_attrs)*
1743        #vis enum #enum_name #generics {
1744            #variants
1745        }
1746    };
1747
1748    // Generate the complete output
1749    let expanded = quote! {
1750        // Output the enum without the #[intent(...)] attribute
1751        #enum_output
1752
1753        // Generate the prompt-building function
1754        pub fn #function_name(#(#function_params),*) -> String {
1755            let mut env = minijinja::Environment::new();
1756            env.add_template("prompt", #converted_template)
1757                .expect("Failed to parse intent prompt template");
1758
1759            let tmpl = env.get_template("prompt").unwrap();
1760
1761            let mut __template_context = std::collections::HashMap::new();
1762
1763            // Add intents_doc
1764            __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
1765
1766            // Add user-provided variables
1767            #(#context_insertions)*
1768
1769            tmpl.render(&__template_context)
1770                .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
1771        }
1772
1773        // Generate the extractor struct
1774        pub struct #extractor_name;
1775
1776        impl #extractor_name {
1777            pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
1778        }
1779
1780        impl llm_toolkit::intent::IntentExtractor<#enum_name> for #extractor_name {
1781            fn extract_intent(&self, response: &str) -> Result<#enum_name, llm_toolkit::intent::IntentExtractionError> {
1782                // Use the common extraction function with our tag
1783                llm_toolkit::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
1784            }
1785        }
1786    };
1787
1788    TokenStream::from(expanded)
1789}
1790
/// Converts a PascalCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character only when it is
/// not the first character and the preceding character was not uppercase.
/// Runs of capitals therefore stay together: `"HTTPServer"` becomes `"httpserver"`.
/// Each uppercase character is replaced by the first character of its lowercase mapping.
fn to_snake_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut last_was_upper = false;

    for (idx, c) in s.chars().enumerate() {
        let is_upper = c.is_uppercase();
        if is_upper && idx != 0 && !last_was_upper {
            out.push('_');
        }
        if is_upper {
            // `to_lowercase` always yields at least one char; keep only the first.
            out.extend(c.to_lowercase().take(1));
        } else {
            out.push(c);
        }
        last_was_upper = is_upper;
    }

    out
}
1811
1812/// Derives the `ToPromptFor` trait for a struct
1813#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
1814pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
1815    let input = parse_macro_input!(input as DeriveInput);
1816
1817    // Parse the struct-level prompt_for attribute
1818    let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
1819
1820    let struct_name = &input.ident;
1821    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1822
1823    // Parse the template to find placeholders
1824    let placeholders = parse_template_placeholders_with_mode(&template);
1825
1826    // Convert template to minijinja syntax and build context generation code
1827    let mut converted_template = template.clone();
1828    let mut context_fields = Vec::new();
1829
1830    // Get struct fields for validation
1831    let fields = match &input.data {
1832        Data::Struct(data_struct) => match &data_struct.fields {
1833            syn::Fields::Named(fields) => &fields.named,
1834            _ => panic!("ToPromptFor is only supported for structs with named fields"),
1835        },
1836        _ => panic!("ToPromptFor is only supported for structs"),
1837    };
1838
1839    // Check if the struct has mode support (has #[prompt(mode = ...)] attribute)
1840    let has_mode_support = input.attrs.iter().any(|attr| {
1841        if attr.path().is_ident("prompt")
1842            && let Ok(metas) =
1843                attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1844        {
1845            for meta in metas {
1846                if let Meta::NameValue(nv) = meta
1847                    && nv.path.is_ident("mode")
1848                {
1849                    return true;
1850                }
1851            }
1852        }
1853        false
1854    });
1855
1856    // Process each placeholder
1857    for (placeholder_name, mode_opt) in &placeholders {
1858        if placeholder_name == "self" {
1859            if let Some(specific_mode) = mode_opt {
1860                // {self:some_mode} - use a unique key
1861                let unique_key = format!("self__{}", specific_mode);
1862
1863                // Replace {{ self:mode }} with {{ self__mode }} in template
1864                let pattern = format!("{{{{ self:{} }}}}", specific_mode);
1865                let replacement = format!("{{{{ {} }}}}", unique_key);
1866                converted_template = converted_template.replace(&pattern, &replacement);
1867
1868                // Add to context with the specific mode
1869                context_fields.push(quote! {
1870                    context.insert(
1871                        #unique_key.to_string(),
1872                        minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
1873                    );
1874                });
1875            } else {
1876                // {{self}} - already in correct format, no replacement needed
1877
1878                if has_mode_support {
1879                    // If the struct has mode support, use to_prompt_with_mode with the mode parameter
1880                    context_fields.push(quote! {
1881                        context.insert(
1882                            "self".to_string(),
1883                            minijinja::Value::from(self.to_prompt_with_mode(mode))
1884                        );
1885                    });
1886                } else {
1887                    // If the struct doesn't have mode support, use to_prompt() which gives key-value format
1888                    context_fields.push(quote! {
1889                        context.insert(
1890                            "self".to_string(),
1891                            minijinja::Value::from(self.to_prompt())
1892                        );
1893                    });
1894                }
1895            }
1896        } else {
1897            // It's a field placeholder
1898            // Check if the field exists
1899            let field_exists = fields.iter().any(|f| {
1900                f.ident
1901                    .as_ref()
1902                    .is_some_and(|ident| ident == placeholder_name)
1903            });
1904
1905            if field_exists {
1906                let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
1907
1908                // {{field}} - already in correct format, no replacement needed
1909
1910                // Add field to context - serialize the field value
1911                context_fields.push(quote! {
1912                    context.insert(
1913                        #placeholder_name.to_string(),
1914                        minijinja::Value::from_serialize(&self.#field_ident)
1915                    );
1916                });
1917            }
1918            // If field doesn't exist, we'll let minijinja handle the error at runtime
1919        }
1920    }
1921
1922    let expanded = quote! {
1923        impl #impl_generics llm_toolkit::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
1924        where
1925            #target_type: serde::Serialize,
1926        {
1927            fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
1928                // Create minijinja environment and add template
1929                let mut env = minijinja::Environment::new();
1930                env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
1931                    panic!("Failed to parse template: {}", e)
1932                });
1933
1934                let tmpl = env.get_template("prompt").unwrap();
1935
1936                // Build context
1937                let mut context = std::collections::HashMap::new();
1938                // Add self to the context for field access in templates
1939                context.insert(
1940                    "self".to_string(),
1941                    minijinja::Value::from_serialize(self)
1942                );
1943                // Add target to the context
1944                context.insert(
1945                    "target".to_string(),
1946                    minijinja::Value::from_serialize(target)
1947                );
1948                #(#context_fields)*
1949
1950                // Render template
1951                tmpl.render(context).unwrap_or_else(|e| {
1952                    format!("Failed to render prompt: {}", e)
1953                })
1954            }
1955        }
1956    };
1957
1958    TokenStream::from(expanded)
1959}