llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use regex::Regex;
4use syn::{
5    Data, DeriveInput, Meta, Token,
6    parse::{Parse, ParseStream},
7    parse_macro_input,
8    punctuated::Punctuated,
9};
10
11/// Parse template placeholders using regex to find :mode patterns
12/// Returns a list of (field_name, optional_mode)
13fn parse_template_placeholders_with_mode(template: &str) -> Vec<(String, Option<String>)> {
14    let mut placeholders = Vec::new();
15    let mut seen_fields = std::collections::HashSet::new();
16
17    // First, find all {{ field:mode }} patterns
18    let mode_pattern = Regex::new(r"\{\{\s*(\w+)\s*:\s*(\w+)\s*\}\}").unwrap();
19    for cap in mode_pattern.captures_iter(template) {
20        let field_name = cap[1].to_string();
21        let mode = cap[2].to_string();
22        placeholders.push((field_name.clone(), Some(mode)));
23        seen_fields.insert(field_name);
24    }
25
26    // Then, find all standard {{ field }} patterns (without mode)
27    let standard_pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
28    for cap in standard_pattern.captures_iter(template) {
29        let field_name = cap[1].to_string();
30        // Check if this field was already captured with a mode
31        if !seen_fields.contains(&field_name) {
32            placeholders.push((field_name, None));
33        }
34    }
35
36    placeholders
37}
38
39/// Extract doc comments from attributes
40fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
41    attrs
42        .iter()
43        .filter_map(|attr| {
44            if attr.path().is_ident("doc")
45                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
46                && let syn::Expr::Lit(syn::ExprLit {
47                    lit: syn::Lit::Str(lit_str),
48                    ..
49                }) = &meta_name_value.value
50            {
51                return Some(lit_str.value());
52            }
53            None
54        })
55        .map(|s| s.trim().to_string())
56        .collect::<Vec<_>>()
57        .join(" ")
58}
59
60/// Generate example JSON representation for a struct
61fn generate_example_only_parts(
62    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
63    has_default: bool,
64) -> proc_macro2::TokenStream {
65    let mut field_values = Vec::new();
66
67    for field in fields.iter() {
68        let field_name = field.ident.as_ref().unwrap();
69        let field_name_str = field_name.to_string();
70        let attrs = parse_field_prompt_attrs(&field.attrs);
71
72        // Skip if marked to skip
73        if attrs.skip {
74            continue;
75        }
76
77        // Check if field has example attribute
78        if let Some(example) = attrs.example {
79            // Use the provided example value
80            field_values.push(quote! {
81                json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
82            });
83        } else if has_default {
84            // Use Default value if available
85            field_values.push(quote! {
86                let default_value = serde_json::to_value(&default_instance.#field_name)
87                    .unwrap_or(serde_json::Value::Null);
88                json_obj.insert(#field_name_str.to_string(), default_value);
89            });
90        } else {
91            // Use self's actual value
92            field_values.push(quote! {
93                let value = serde_json::to_value(&self.#field_name)
94                    .unwrap_or(serde_json::Value::Null);
95                json_obj.insert(#field_name_str.to_string(), value);
96            });
97        }
98    }
99
100    if has_default {
101        quote! {
102            {
103                let default_instance = Self::default();
104                let mut json_obj = serde_json::Map::new();
105                #(#field_values)*
106                let json_value = serde_json::Value::Object(json_obj);
107                let json_str = serde_json::to_string_pretty(&json_value)
108                    .unwrap_or_else(|_| "{}".to_string());
109                vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
110            }
111        }
112    } else {
113        quote! {
114            {
115                let mut json_obj = serde_json::Map::new();
116                #(#field_values)*
117                let json_value = serde_json::Value::Object(json_obj);
118                let json_str = serde_json::to_string_pretty(&json_value)
119                    .unwrap_or_else(|_| "{}".to_string());
120                vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
121            }
122        }
123    }
124}
125
126/// Generate schema-only representation for a struct
127fn generate_schema_only_parts(
128    struct_name: &str,
129    struct_docs: &str,
130    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
131) -> proc_macro2::TokenStream {
132    let mut schema_lines = vec![];
133
134    // Add header
135    if !struct_docs.is_empty() {
136        schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
137    } else {
138        schema_lines.push(format!("### Schema for `{}`", struct_name));
139    }
140
141    schema_lines.push("{".to_string());
142
143    // Process fields
144    for (i, field) in fields.iter().enumerate() {
145        let field_name = field.ident.as_ref().unwrap();
146        let attrs = parse_field_prompt_attrs(&field.attrs);
147
148        // Skip if marked to skip
149        if attrs.skip {
150            continue;
151        }
152
153        // Get field documentation
154        let field_docs = extract_doc_comments(&field.attrs);
155
156        // Determine the type representation
157        let type_str = format_type_for_schema(&field.ty);
158
159        // Build field line
160        let mut field_line = format!("  \"{}\": \"{}\"", field_name, type_str);
161
162        // Add comment if there's documentation
163        if !field_docs.is_empty() {
164            field_line.push_str(&format!(", // {}", field_docs));
165        }
166
167        // Add comma if not last field (accounting for skipped fields)
168        let remaining_fields = fields
169            .iter()
170            .skip(i + 1)
171            .filter(|f| {
172                let attrs = parse_field_prompt_attrs(&f.attrs);
173                !attrs.skip
174            })
175            .count();
176
177        if remaining_fields > 0 {
178            field_line.push(',');
179        }
180
181        schema_lines.push(field_line);
182    }
183
184    schema_lines.push("}".to_string());
185
186    let schema_str = schema_lines.join("\n");
187
188    quote! {
189        vec![llm_toolkit::prompt::PromptPart::Text(#schema_str.to_string())]
190    }
191}
192
193/// Format a type for schema representation
194fn format_type_for_schema(ty: &syn::Type) -> String {
195    // Simple type formatting - can be enhanced
196    match ty {
197        syn::Type::Path(type_path) => {
198            let path = &type_path.path;
199            if let Some(last_segment) = path.segments.last() {
200                let type_name = last_segment.ident.to_string();
201
202                // Handle Option<T>
203                if type_name == "Option"
204                    && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
205                    && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
206                {
207                    return format!("{} | null", format_type_for_schema(inner_type));
208                }
209
210                // Map common types
211                match type_name.as_str() {
212                    "String" | "str" => "string".to_string(),
213                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
214                    | "u64" | "u128" | "usize" => "number".to_string(),
215                    "f32" | "f64" => "number".to_string(),
216                    "bool" => "boolean".to_string(),
217                    "Vec" => {
218                        if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
219                            && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
220                        {
221                            return format!("{}[]", format_type_for_schema(inner_type));
222                        }
223                        "array".to_string()
224                    }
225                    _ => type_name.to_lowercase(),
226                }
227            } else {
228                "unknown".to_string()
229            }
230        }
231        _ => "unknown".to_string(),
232    }
233}
234
/// Outcome of reading the `#[prompt(...)]` attribute on an enum variant.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the generated prompt.
    Skip,
    /// `#[prompt("...")]`: this text is used as the variant's description.
    Description(String),
    /// No usable `#[prompt]` attribute: fall back to the variant's doc comment, or to the
    /// bare variant name when it has none.
    None,
}
241
242/// Parse #[prompt(...)] attribute on enum variant
243fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
244    for attr in attrs {
245        if attr.path().is_ident("prompt") {
246            // Check for #[prompt(skip)]
247            if let Ok(meta_list) = attr.meta.require_list() {
248                let tokens = &meta_list.tokens;
249                let tokens_str = tokens.to_string();
250                if tokens_str == "skip" {
251                    return PromptAttribute::Skip;
252                }
253            }
254
255            // Check for #[prompt("description")]
256            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
257                return PromptAttribute::Description(lit_str.value());
258            }
259        }
260    }
261    PromptAttribute::None
262}
263
/// Field-level options collected from `#[prompt(...)]` attributes on a struct field.
///
/// `Default` is the "no options" state: not skipped, no rename, no formatter, not an
/// image, no example.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct FieldPromptAttrs {
    /// `#[prompt(skip)]`: the schema and example generators leave this field out.
    skip: bool,
    /// `#[prompt(rename = "...")]`: alternative name for the field.
    /// NOTE(review): not read by the schema/example generators; presumably used by the
    /// template / key-value generation path — confirm.
    rename: Option<String>,
    /// `#[prompt(format_with = "...")]`: stored as a string.
    /// NOTE(review): presumably the path of a formatting function — usage not visible
    /// here, confirm.
    format_with: Option<String>,
    /// `#[prompt(image)]`: marks the field as an image.
    /// NOTE(review): presumably emitted as an image prompt part — confirm.
    image: bool,
    /// `#[prompt(example = "...")]`: example value; the example generator inserts it as
    /// a JSON string.
    example: Option<String>,
}
273
274/// Parse #[prompt(...)] attributes for struct fields
275fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
276    let mut result = FieldPromptAttrs::default();
277
278    for attr in attrs {
279        if attr.path().is_ident("prompt") {
280            // Try to parse as meta list #[prompt(key = value, ...)]
281            if let Ok(meta_list) = attr.meta.require_list() {
282                // Parse the tokens inside the parentheses
283                if let Ok(metas) =
284                    meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
285                {
286                    for meta in metas {
287                        match meta {
288                            Meta::Path(path) if path.is_ident("skip") => {
289                                result.skip = true;
290                            }
291                            Meta::NameValue(nv) if nv.path.is_ident("rename") => {
292                                if let syn::Expr::Lit(syn::ExprLit {
293                                    lit: syn::Lit::Str(lit_str),
294                                    ..
295                                }) = nv.value
296                                {
297                                    result.rename = Some(lit_str.value());
298                                }
299                            }
300                            Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
301                                if let syn::Expr::Lit(syn::ExprLit {
302                                    lit: syn::Lit::Str(lit_str),
303                                    ..
304                                }) = nv.value
305                                {
306                                    result.format_with = Some(lit_str.value());
307                                }
308                            }
309                            Meta::Path(path) if path.is_ident("image") => {
310                                result.image = true;
311                            }
312                            Meta::NameValue(nv) if nv.path.is_ident("example") => {
313                                if let syn::Expr::Lit(syn::ExprLit {
314                                    lit: syn::Lit::Str(lit_str),
315                                    ..
316                                }) = nv.value
317                                {
318                                    result.example = Some(lit_str.value());
319                                }
320                            }
321                            _ => {}
322                        }
323                    }
324                } else if meta_list.tokens.to_string() == "skip" {
325                    // Handle simple #[prompt(skip)] case
326                    result.skip = true;
327                } else if meta_list.tokens.to_string() == "image" {
328                    // Handle simple #[prompt(image)] case
329                    result.image = true;
330                }
331            }
332        }
333    }
334
335    result
336}
337
338/// Derives the `ToPrompt` trait for a struct or enum.
339///
340/// This macro provides two main functionalities depending on the type.
341///
342/// ## For Structs
343///
344/// It can generate a prompt based on a template string or by creating a key-value representation of the struct's fields.
345///
346/// ### Template-based Prompt
347///
348/// Use the `#[prompt(template = "...")]` attribute to provide a `minijinja` template. The struct fields will be available as variables in the template. The struct must also derive `serde::Serialize`.
349///
350/// ```rust,ignore
351/// #[derive(ToPrompt, Serialize)]
352/// #[prompt(template = "User {{ name }} is a {{ role }}.")]
353/// struct UserProfile {
354///     name: &'static str,
355///     role: &'static str,
356/// }
357/// ```
358///
359/// ### Tip: Handling Special Characters in Templates
360///
361/// When using raw string literals (e.g., `r#"..."#`) for your templates, be aware of a potential parsing issue if your template content includes the `#` character. To avoid this, use a different number of `#` symbols for the raw string delimiter.
362///
363/// **Problematic Example:**
364/// ```rust,ignore
365/// // This might fail to parse correctly
366/// #[prompt(template = r#"{"color": "#FFFFFF"}"#)]
367/// struct Color { /* ... */ }
368/// ```
369///
370/// **Solution:**
371/// ```rust,ignore
372/// // Use r##"..."## to avoid ambiguity with the inner '#'
373/// #[prompt(template = r##"{"color": "#FFFFFF"}"##)]
374/// struct Color { /* ... */ }
375/// ```
376///
377/// ## For Enums
378///
379/// For enums, the macro generates a descriptive prompt based on doc comments and attributes, outlining the available variants. See the documentation on the `ToPrompt` trait for more details.
380#[proc_macro_derive(ToPrompt, attributes(prompt))]
381pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
382    let input = parse_macro_input!(input as DeriveInput);
383
384    // Check if this is a struct or enum
385    match &input.data {
386        Data::Enum(data_enum) => {
387            // For enums, generate prompt from doc comments
388            let enum_name = &input.ident;
389            let enum_docs = extract_doc_comments(&input.attrs);
390
391            let mut prompt_lines = Vec::new();
392
393            // Add enum description
394            if !enum_docs.is_empty() {
395                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
396            } else {
397                prompt_lines.push(format!("{}:", enum_name));
398            }
399            prompt_lines.push(String::new()); // Empty line
400            prompt_lines.push("Possible values:".to_string());
401
402            // Add each variant with its documentation based on priority
403            for variant in &data_enum.variants {
404                let variant_name = &variant.ident;
405
406                // Apply fallback logic with priority
407                match parse_prompt_attribute(&variant.attrs) {
408                    PromptAttribute::Skip => {
409                        // Skip this variant completely
410                        continue;
411                    }
412                    PromptAttribute::Description(desc) => {
413                        // Use custom description from #[prompt("...")]
414                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
415                    }
416                    PromptAttribute::None => {
417                        // Fall back to doc comment or just variant name
418                        let variant_docs = extract_doc_comments(&variant.attrs);
419                        if !variant_docs.is_empty() {
420                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
421                        } else {
422                            prompt_lines.push(format!("- {}", variant_name));
423                        }
424                    }
425                }
426            }
427
428            let prompt_string = prompt_lines.join("\n");
429            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
430
431            let expanded = quote! {
432                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
433                    fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
434                        vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
435                    }
436
437                    fn to_prompt(&self) -> String {
438                        #prompt_string.to_string()
439                    }
440                }
441            };
442
443            TokenStream::from(expanded)
444        }
445        Data::Struct(data_struct) => {
446            // Parse struct-level prompt attributes for template, template_file, mode, and validate
447            let mut template_attr = None;
448            let mut template_file_attr = None;
449            let mut mode_attr = None;
450            let mut validate_attr = false;
451
452            for attr in &input.attrs {
453                if attr.path().is_ident("prompt") {
454                    // Try to parse the attribute arguments
455                    if let Ok(metas) =
456                        attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
457                    {
458                        for meta in metas {
459                            match meta {
460                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
461                                    if let syn::Expr::Lit(expr_lit) = nv.value
462                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
463                                    {
464                                        template_attr = Some(lit_str.value());
465                                    }
466                                }
467                                Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
468                                    if let syn::Expr::Lit(expr_lit) = nv.value
469                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
470                                    {
471                                        template_file_attr = Some(lit_str.value());
472                                    }
473                                }
474                                Meta::NameValue(nv) if nv.path.is_ident("mode") => {
475                                    if let syn::Expr::Lit(expr_lit) = nv.value
476                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
477                                    {
478                                        mode_attr = Some(lit_str.value());
479                                    }
480                                }
481                                Meta::NameValue(nv) if nv.path.is_ident("validate") => {
482                                    if let syn::Expr::Lit(expr_lit) = nv.value
483                                        && let syn::Lit::Bool(lit_bool) = expr_lit.lit
484                                    {
485                                        validate_attr = lit_bool.value();
486                                    }
487                                }
488                                _ => {}
489                            }
490                        }
491                    }
492                }
493            }
494
495            // Check for mutual exclusivity between template and template_file
496            if template_attr.is_some() && template_file_attr.is_some() {
497                return syn::Error::new(
498                    input.ident.span(),
499                    "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
500                ).to_compile_error().into();
501            }
502
503            // Load template from file if template_file is specified
504            let template_str = if let Some(file_path) = template_file_attr {
505                // Try multiple strategies to find the template file
506                // This is necessary to support both normal compilation and trybuild tests
507
508                let mut full_path = None;
509
510                // Strategy 1: Try relative to CARGO_MANIFEST_DIR (normal compilation)
511                if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
512                    // Check if this is a trybuild temporary directory
513                    let is_trybuild = manifest_dir.contains("target/tests/trybuild");
514
515                    if !is_trybuild {
516                        // Normal compilation - use CARGO_MANIFEST_DIR directly
517                        let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
518                        if candidate.exists() {
519                            full_path = Some(candidate);
520                        }
521                    } else {
522                        // For trybuild, we need to find the original source directory
523                        // The manifest_dir looks like: .../target/tests/trybuild/llm-toolkit-macros
524                        // We need to get back to the original llm-toolkit-macros source directory
525
526                        // Extract the workspace root from the path
527                        if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
528                            let workspace_root = &manifest_dir[..target_pos];
529                            // Now construct the path to the original llm-toolkit-macros source
530                            let original_macros_dir = std::path::Path::new(workspace_root)
531                                .join("crates")
532                                .join("llm-toolkit-macros");
533
534                            let candidate = original_macros_dir.join(&file_path);
535                            if candidate.exists() {
536                                full_path = Some(candidate);
537                            }
538                        }
539                    }
540                }
541
542                // Strategy 2: Try as an absolute path or relative to current directory
543                if full_path.is_none() {
544                    let candidate = std::path::Path::new(&file_path).to_path_buf();
545                    if candidate.exists() {
546                        full_path = Some(candidate);
547                    }
548                }
549
550                // Strategy 3: For trybuild tests - try to find the file by looking in parent directories
551                // This handles the case where trybuild creates a temporary project
552                if full_path.is_none()
553                    && let Ok(current_dir) = std::env::current_dir()
554                {
555                    let mut search_dir = current_dir.as_path();
556                    // Search up to 10 levels up
557                    for _ in 0..10 {
558                        // Try from the llm-toolkit-macros directory
559                        let macros_dir = search_dir.join("crates/llm-toolkit-macros");
560                        if macros_dir.exists() {
561                            let candidate = macros_dir.join(&file_path);
562                            if candidate.exists() {
563                                full_path = Some(candidate);
564                                break;
565                            }
566                        }
567                        // Try directly
568                        let candidate = search_dir.join(&file_path);
569                        if candidate.exists() {
570                            full_path = Some(candidate);
571                            break;
572                        }
573                        if let Some(parent) = search_dir.parent() {
574                            search_dir = parent;
575                        } else {
576                            break;
577                        }
578                    }
579                }
580
581                // If we still haven't found the file, use the original path for a better error message
582                let final_path =
583                    full_path.unwrap_or_else(|| std::path::Path::new(&file_path).to_path_buf());
584
585                // Read the file at compile time
586                match std::fs::read_to_string(&final_path) {
587                    Ok(content) => Some(content),
588                    Err(e) => {
589                        return syn::Error::new(
590                            input.ident.span(),
591                            format!(
592                                "Failed to read template file '{}': {}",
593                                final_path.display(),
594                                e
595                            ),
596                        )
597                        .to_compile_error()
598                        .into();
599                    }
600                }
601            } else {
602                template_attr
603            };
604
605            // Perform validation if requested
606            if validate_attr && let Some(template) = &template_str {
607                // Validate Jinja syntax
608                let mut env = minijinja::Environment::new();
609                if let Err(e) = env.add_template("validation", template) {
610                    // Generate a compile warning using deprecated const hack
611                    let warning_msg =
612                        format!("Template validation warning: Invalid Jinja syntax - {}", e);
613                    let warning_ident = syn::Ident::new(
614                        "TEMPLATE_VALIDATION_WARNING",
615                        proc_macro2::Span::call_site(),
616                    );
617                    let _warning_tokens = quote! {
618                        #[deprecated(note = #warning_msg)]
619                        const #warning_ident: () = ();
620                        let _ = #warning_ident;
621                    };
622                    // We'll inject this warning into the generated code
623                    eprintln!("cargo:warning={}", warning_msg);
624                }
625
626                // Extract variables from template and check against struct fields
627                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
628                    &fields.named
629                } else {
630                    panic!("Template validation is only supported for structs with named fields.");
631                };
632
633                let field_names: std::collections::HashSet<String> = fields
634                    .iter()
635                    .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
636                    .collect();
637
638                // Parse template placeholders
639                let placeholders = parse_template_placeholders_with_mode(template);
640
641                for (placeholder_name, _mode) in &placeholders {
642                    if placeholder_name != "self" && !field_names.contains(placeholder_name) {
643                        let warning_msg = format!(
644                            "Template validation warning: Variable '{}' used in template but not found in struct fields",
645                            placeholder_name
646                        );
647                        eprintln!("cargo:warning={}", warning_msg);
648                    }
649                }
650            }
651
652            let name = input.ident;
653            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
654
655            // Extract struct name and doc comment for use in schema generation
656            let struct_docs = extract_doc_comments(&input.attrs);
657
658            // Check if this is a mode-based struct (mode attribute present)
659            let is_mode_based =
660                mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
661
662            let expanded = if is_mode_based || mode_attr.is_some() {
663                // Mode-based generation: support schema_only, example_only, full
664                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
665                    &fields.named
666                } else {
667                    panic!(
668                        "Mode-based prompt generation is only supported for structs with named fields."
669                    );
670                };
671
672                let struct_name_str = name.to_string();
673
674                // Check if struct derives Default
675                let has_default = input.attrs.iter().any(|attr| {
676                    if attr.path().is_ident("derive")
677                        && let Ok(meta_list) = attr.meta.require_list()
678                    {
679                        let tokens_str = meta_list.tokens.to_string();
680                        tokens_str.contains("Default")
681                    } else {
682                        false
683                    }
684                });
685
686                // Generate schema-only parts
687                let schema_parts =
688                    generate_schema_only_parts(&struct_name_str, &struct_docs, fields);
689
690                // Generate example parts
691                let example_parts = generate_example_only_parts(fields, has_default);
692
693                quote! {
694                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
695                        fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<llm_toolkit::prompt::PromptPart> {
696                            match mode {
697                                "schema_only" => #schema_parts,
698                                "example_only" => #example_parts,
699                                "full" | _ => {
700                                    // Combine schema and example
701                                    let mut parts = Vec::new();
702
703                                    // Add schema
704                                    let schema_parts = #schema_parts;
705                                    parts.extend(schema_parts);
706
707                                    // Add separator and example header
708                                    parts.push(llm_toolkit::prompt::PromptPart::Text("\n### Example".to_string()));
709                                    parts.push(llm_toolkit::prompt::PromptPart::Text(
710                                        format!("Here is an example of a valid `{}` object:", #struct_name_str)
711                                    ));
712
713                                    // Add example
714                                    let example_parts = #example_parts;
715                                    parts.extend(example_parts);
716
717                                    parts
718                                }
719                            }
720                        }
721
722                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
723                            self.to_prompt_parts_with_mode("full")
724                        }
725
726                        fn to_prompt(&self) -> String {
727                            self.to_prompt_parts()
728                                .into_iter()
729                                .filter_map(|part| match part {
730                                    llm_toolkit::prompt::PromptPart::Text(text) => Some(text),
731                                    _ => None,
732                                })
733                                .collect::<Vec<_>>()
734                                .join("\n")
735                        }
736                    }
737                }
738            } else if let Some(template) = template_str {
739                // Use template-based approach if template is provided
740                // Collect image fields separately for to_prompt_parts()
741                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
742                    &fields.named
743                } else {
744                    panic!(
745                        "Template prompt generation is only supported for structs with named fields."
746                    );
747                };
748
749                // Parse template to detect mode syntax
750                let placeholders = parse_template_placeholders_with_mode(&template);
751                // Only use custom mode processing if template actually contains :mode syntax
752                let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
753                    mode.is_some()
754                        && fields
755                            .iter()
756                            .any(|f| f.ident.as_ref().unwrap() == field_name)
757                });
758
759                let mut image_field_parts = Vec::new();
760                for f in fields.iter() {
761                    let field_name = f.ident.as_ref().unwrap();
762                    let attrs = parse_field_prompt_attrs(&f.attrs);
763
764                    if attrs.image {
765                        // This field is marked as an image
766                        image_field_parts.push(quote! {
767                            parts.extend(self.#field_name.to_prompt_parts());
768                        });
769                    }
770                }
771
772                // Generate appropriate code based on whether mode syntax is used
773                if has_mode_syntax {
774                    // Build custom context for fields with mode specifications
775                    let mut context_fields = Vec::new();
776                    let mut modified_template = template.clone();
777
778                    // Process each placeholder with mode
779                    for (field_name, mode_opt) in &placeholders {
780                        if let Some(mode) = mode_opt {
781                            // Create a unique key for this field:mode combination
782                            let unique_key = format!("{}__{}", field_name, mode);
783
784                            // Replace {{ field:mode }} with {{ field__mode }} in template
785                            let pattern = format!("{{{{ {}:{} }}}}", field_name, mode);
786                            let replacement = format!("{{{{ {} }}}}", unique_key);
787                            modified_template = modified_template.replace(&pattern, &replacement);
788
789                            // Find the corresponding field
790                            let field_ident =
791                                syn::Ident::new(field_name, proc_macro2::Span::call_site());
792
793                            // Add to context with mode specification
794                            context_fields.push(quote! {
795                                context.insert(
796                                    #unique_key.to_string(),
797                                    minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
798                                );
799                            });
800                        }
801                    }
802
803                    // Add individual fields via direct access (for non-mode fields)
804                    for field in fields.iter() {
805                        let field_name = field.ident.as_ref().unwrap();
806                        let field_name_str = field_name.to_string();
807
808                        // Skip if this field already has a mode-specific entry
809                        let has_mode_entry = placeholders
810                            .iter()
811                            .any(|(name, mode)| name == &field_name_str && mode.is_some());
812
813                        if !has_mode_entry {
814                            // Check if field type is likely a struct that implements ToPrompt
815                            // (not a primitive type)
816                            let is_primitive = match &field.ty {
817                                syn::Type::Path(type_path) => {
818                                    if let Some(segment) = type_path.path.segments.last() {
819                                        let type_name = segment.ident.to_string();
820                                        matches!(
821                                            type_name.as_str(),
822                                            "String"
823                                                | "str"
824                                                | "i8"
825                                                | "i16"
826                                                | "i32"
827                                                | "i64"
828                                                | "i128"
829                                                | "isize"
830                                                | "u8"
831                                                | "u16"
832                                                | "u32"
833                                                | "u64"
834                                                | "u128"
835                                                | "usize"
836                                                | "f32"
837                                                | "f64"
838                                                | "bool"
839                                                | "char"
840                                        )
841                                    } else {
842                                        false
843                                    }
844                                }
845                                _ => false,
846                            };
847
848                            if is_primitive {
849                                context_fields.push(quote! {
850                                    context.insert(
851                                        #field_name_str.to_string(),
852                                        minijinja::Value::from_serialize(&self.#field_name)
853                                    );
854                                });
855                            } else {
856                                // For non-primitive types, use to_prompt()
857                                context_fields.push(quote! {
858                                    context.insert(
859                                        #field_name_str.to_string(),
860                                        minijinja::Value::from(self.#field_name.to_prompt())
861                                    );
862                                });
863                            }
864                        }
865                    }
866
867                    quote! {
868                        impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
869                            fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
870                                let mut parts = Vec::new();
871
872                                // Add image parts first
873                                #(#image_field_parts)*
874
875                                // Build custom context and render template
876                                let text = {
877                                    let mut env = minijinja::Environment::new();
878                                    env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
879                                        panic!("Failed to parse template: {}", e)
880                                    });
881
882                                    let tmpl = env.get_template("prompt").unwrap();
883
884                                    let mut context = std::collections::HashMap::new();
885                                    #(#context_fields)*
886
887                                    tmpl.render(context).unwrap_or_else(|e| {
888                                        format!("Failed to render prompt: {}", e)
889                                    })
890                                };
891
892                                if !text.is_empty() {
893                                    parts.push(llm_toolkit::prompt::PromptPart::Text(text));
894                                }
895
896                                parts
897                            }
898
899                            fn to_prompt(&self) -> String {
900                                // Same logic for to_prompt
901                                let mut env = minijinja::Environment::new();
902                                env.add_template("prompt", #modified_template).unwrap_or_else(|e| {
903                                    panic!("Failed to parse template: {}", e)
904                                });
905
906                                let tmpl = env.get_template("prompt").unwrap();
907
908                                let mut context = std::collections::HashMap::new();
909                                #(#context_fields)*
910
911                                tmpl.render(context).unwrap_or_else(|e| {
912                                    format!("Failed to render prompt: {}", e)
913                                })
914                            }
915                        }
916                    }
917                } else {
918                    // No mode syntax, use direct template rendering with render_prompt
919                    quote! {
920                        impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
921                            fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
922                                let mut parts = Vec::new();
923
924                                // Add image parts first
925                                #(#image_field_parts)*
926
927                                // Add the rendered template as text
928                                let text = llm_toolkit::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
929                                    format!("Failed to render prompt: {}", e)
930                                });
931                                if !text.is_empty() {
932                                    parts.push(llm_toolkit::prompt::PromptPart::Text(text));
933                                }
934
935                                parts
936                            }
937
938                            fn to_prompt(&self) -> String {
939                                llm_toolkit::prompt::render_prompt(#template, self).unwrap_or_else(|e| {
940                                    format!("Failed to render prompt: {}", e)
941                                })
942                            }
943                        }
944                    }
945                }
946            } else {
947                // Use default key-value format if no template is provided
948                // Now also generate to_prompt_parts() for multimodal support
949                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
950                    &fields.named
951                } else {
952                    panic!(
953                        "Default prompt generation is only supported for structs with named fields."
954                    );
955                };
956
957                // Separate image fields from text fields
958                let mut text_field_parts = Vec::new();
959                let mut image_field_parts = Vec::new();
960
961                for f in fields.iter() {
962                    let field_name = f.ident.as_ref().unwrap();
963                    let attrs = parse_field_prompt_attrs(&f.attrs);
964
965                    // Skip if #[prompt(skip)] is present
966                    if attrs.skip {
967                        continue;
968                    }
969
970                    if attrs.image {
971                        // This field is marked as an image
972                        image_field_parts.push(quote! {
973                            parts.extend(self.#field_name.to_prompt_parts());
974                        });
975                    } else {
976                        // This is a regular text field
977                        // Determine the key based on priority:
978                        // 1. #[prompt(rename = "new_name")]
979                        // 2. Doc comment
980                        // 3. Field name (fallback)
981                        let key = if let Some(rename) = attrs.rename {
982                            rename
983                        } else {
984                            let doc_comment = extract_doc_comments(&f.attrs);
985                            if !doc_comment.is_empty() {
986                                doc_comment
987                            } else {
988                                field_name.to_string()
989                            }
990                        };
991
992                        // Determine the value based on format_with attribute
993                        let value_expr = if let Some(format_with) = attrs.format_with {
994                            // Parse the function path string into a syn::Path
995                            let func_path: syn::Path =
996                                syn::parse_str(&format_with).unwrap_or_else(|_| {
997                                    panic!("Invalid function path: {}", format_with)
998                                });
999                            quote! { #func_path(&self.#field_name) }
1000                        } else {
1001                            quote! { self.#field_name.to_prompt() }
1002                        };
1003
1004                        text_field_parts.push(quote! {
1005                            text_parts.push(format!("{}: {}", #key, #value_expr));
1006                        });
1007                    }
1008                }
1009
1010                // Generate the implementation with to_prompt_parts()
1011                quote! {
1012                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
1013                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
1014                            let mut parts = Vec::new();
1015
1016                            // Add image parts first
1017                            #(#image_field_parts)*
1018
1019                            // Collect text parts and add as a single text prompt part
1020                            let mut text_parts = Vec::new();
1021                            #(#text_field_parts)*
1022
1023                            if !text_parts.is_empty() {
1024                                parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1025                            }
1026
1027                            parts
1028                        }
1029
1030                        fn to_prompt(&self) -> String {
1031                            let mut text_parts = Vec::new();
1032                            #(#text_field_parts)*
1033                            text_parts.join("\n")
1034                        }
1035                    }
1036                }
1037            };
1038
1039            TokenStream::from(expanded)
1040        }
1041        Data::Union(_) => {
1042            panic!("`#[derive(ToPrompt)]` is not supported for unions");
1043        }
1044    }
1045}
1046
/// A single output target of `#[derive(ToPromptSet)]`.
///
/// Created either from a struct-level `#[prompt_for(name = "...")]` attribute or
/// implicitly from a field-level one that names a target not declared on the struct.
#[derive(Clone, Debug)]
struct TargetInfo {
    /// Name matched against the `target` argument of `to_prompt_parts_for`.
    name: String,
    /// Optional template from struct-level `#[prompt_for(template = "...")]`.
    template: Option<String>,
    /// Per-field settings for this target, keyed by the field's identifier.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}

/// How one field is rendered for one particular target.
#[derive(Clone, Debug, Default)]
struct FieldTargetConfig {
    /// Leave the field out of this target's output.
    skip: bool,
    /// Key printed instead of the field name in key-value output.
    rename: Option<String>,
    /// Path (as a string) of a function called with `&self.field` to format the value.
    format_with: Option<String>,
    /// Emit the field via `to_prompt_parts()` as non-text parts instead of a text line.
    image: bool,
    include_only: bool, // true if this field is specifically included for this target
}
1064
1065/// Parse #[prompt_for(...)] attributes for ToPromptSet
1066fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1067    let mut configs = Vec::new();
1068
1069    for attr in attrs {
1070        if attr.path().is_ident("prompt_for")
1071            && let Ok(meta_list) = attr.meta.require_list()
1072        {
1073            // Try to parse as meta list
1074            if meta_list.tokens.to_string() == "skip" {
1075                // Simple #[prompt_for(skip)] applies to all targets
1076                let config = FieldTargetConfig {
1077                    skip: true,
1078                    ..Default::default()
1079                };
1080                configs.push(("*".to_string(), config));
1081            } else if let Ok(metas) =
1082                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1083            {
1084                let mut target_name = None;
1085                let mut config = FieldTargetConfig::default();
1086
1087                for meta in metas {
1088                    match meta {
1089                        Meta::NameValue(nv) if nv.path.is_ident("name") => {
1090                            if let syn::Expr::Lit(syn::ExprLit {
1091                                lit: syn::Lit::Str(lit_str),
1092                                ..
1093                            }) = nv.value
1094                            {
1095                                target_name = Some(lit_str.value());
1096                            }
1097                        }
1098                        Meta::Path(path) if path.is_ident("skip") => {
1099                            config.skip = true;
1100                        }
1101                        Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1102                            if let syn::Expr::Lit(syn::ExprLit {
1103                                lit: syn::Lit::Str(lit_str),
1104                                ..
1105                            }) = nv.value
1106                            {
1107                                config.rename = Some(lit_str.value());
1108                            }
1109                        }
1110                        Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1111                            if let syn::Expr::Lit(syn::ExprLit {
1112                                lit: syn::Lit::Str(lit_str),
1113                                ..
1114                            }) = nv.value
1115                            {
1116                                config.format_with = Some(lit_str.value());
1117                            }
1118                        }
1119                        Meta::Path(path) if path.is_ident("image") => {
1120                            config.image = true;
1121                        }
1122                        _ => {}
1123                    }
1124                }
1125
1126                if let Some(name) = target_name {
1127                    config.include_only = true;
1128                    configs.push((name, config));
1129                }
1130            }
1131        }
1132    }
1133
1134    configs
1135}
1136
1137/// Parse struct-level #[prompt_for(...)] attributes to find target templates
1138fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1139    let mut targets = Vec::new();
1140
1141    for attr in attrs {
1142        if attr.path().is_ident("prompt_for")
1143            && let Ok(meta_list) = attr.meta.require_list()
1144            && let Ok(metas) =
1145                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1146        {
1147            let mut target_name = None;
1148            let mut template = None;
1149
1150            for meta in metas {
1151                match meta {
1152                    Meta::NameValue(nv) if nv.path.is_ident("name") => {
1153                        if let syn::Expr::Lit(syn::ExprLit {
1154                            lit: syn::Lit::Str(lit_str),
1155                            ..
1156                        }) = nv.value
1157                        {
1158                            target_name = Some(lit_str.value());
1159                        }
1160                    }
1161                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
1162                        if let syn::Expr::Lit(syn::ExprLit {
1163                            lit: syn::Lit::Str(lit_str),
1164                            ..
1165                        }) = nv.value
1166                        {
1167                            template = Some(lit_str.value());
1168                        }
1169                    }
1170                    _ => {}
1171                }
1172            }
1173
1174            if let Some(name) = target_name {
1175                targets.push(TargetInfo {
1176                    name,
1177                    template,
1178                    field_configs: std::collections::HashMap::new(),
1179                });
1180            }
1181        }
1182    }
1183
1184    targets
1185}
1186
1187#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1188pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1189    let input = parse_macro_input!(input as DeriveInput);
1190
1191    // Only support structs with named fields
1192    let data_struct = match &input.data {
1193        Data::Struct(data) => data,
1194        _ => {
1195            return syn::Error::new(
1196                input.ident.span(),
1197                "`#[derive(ToPromptSet)]` is only supported for structs",
1198            )
1199            .to_compile_error()
1200            .into();
1201        }
1202    };
1203
1204    let fields = match &data_struct.fields {
1205        syn::Fields::Named(fields) => &fields.named,
1206        _ => {
1207            return syn::Error::new(
1208                input.ident.span(),
1209                "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1210            )
1211            .to_compile_error()
1212            .into();
1213        }
1214    };
1215
1216    // Parse struct-level attributes to find targets
1217    let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1218
1219    // Parse field-level attributes
1220    for field in fields.iter() {
1221        let field_name = field.ident.as_ref().unwrap().to_string();
1222        let field_configs = parse_prompt_for_attrs(&field.attrs);
1223
1224        for (target_name, config) in field_configs {
1225            if target_name == "*" {
1226                // Apply to all targets
1227                for target in &mut targets {
1228                    target
1229                        .field_configs
1230                        .entry(field_name.clone())
1231                        .or_insert_with(FieldTargetConfig::default)
1232                        .skip = config.skip;
1233                }
1234            } else {
1235                // Find or create the target
1236                let target_exists = targets.iter().any(|t| t.name == target_name);
1237                if !target_exists {
1238                    // Add implicit target if not defined at struct level
1239                    targets.push(TargetInfo {
1240                        name: target_name.clone(),
1241                        template: None,
1242                        field_configs: std::collections::HashMap::new(),
1243                    });
1244                }
1245
1246                let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1247
1248                target.field_configs.insert(field_name.clone(), config);
1249            }
1250        }
1251    }
1252
1253    // Generate match arms for each target
1254    let mut match_arms = Vec::new();
1255
1256    for target in &targets {
1257        let target_name = &target.name;
1258
1259        if let Some(template_str) = &target.template {
1260            // Template-based generation
1261            let mut image_parts = Vec::new();
1262
1263            for field in fields.iter() {
1264                let field_name = field.ident.as_ref().unwrap();
1265                let field_name_str = field_name.to_string();
1266
1267                if let Some(config) = target.field_configs.get(&field_name_str)
1268                    && config.image
1269                {
1270                    image_parts.push(quote! {
1271                        parts.extend(self.#field_name.to_prompt_parts());
1272                    });
1273                }
1274            }
1275
1276            match_arms.push(quote! {
1277                #target_name => {
1278                    let mut parts = Vec::new();
1279
1280                    #(#image_parts)*
1281
1282                    let text = llm_toolkit::prompt::render_prompt(#template_str, self)
1283                        .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
1284                            target: #target_name.to_string(),
1285                            source: e,
1286                        })?;
1287
1288                    if !text.is_empty() {
1289                        parts.push(llm_toolkit::prompt::PromptPart::Text(text));
1290                    }
1291
1292                    Ok(parts)
1293                }
1294            });
1295        } else {
1296            // Key-value based generation
1297            let mut text_field_parts = Vec::new();
1298            let mut image_field_parts = Vec::new();
1299
1300            for field in fields.iter() {
1301                let field_name = field.ident.as_ref().unwrap();
1302                let field_name_str = field_name.to_string();
1303
1304                // Check if field should be included for this target
1305                let config = target.field_configs.get(&field_name_str);
1306
1307                // Skip if explicitly marked to skip
1308                if let Some(cfg) = config
1309                    && cfg.skip
1310                {
1311                    continue;
1312                }
1313
1314                // For non-template targets, only include fields that are:
1315                // 1. Explicitly marked for this target with #[prompt_for(name = "Target")]
1316                // 2. Not marked for any specific target (default fields)
1317                let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1318                let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1319                    .iter()
1320                    .any(|(name, _)| name != "*");
1321
1322                if has_any_target_specific_config && !is_explicitly_for_this_target {
1323                    continue;
1324                }
1325
1326                if let Some(cfg) = config {
1327                    if cfg.image {
1328                        image_field_parts.push(quote! {
1329                            parts.extend(self.#field_name.to_prompt_parts());
1330                        });
1331                    } else {
1332                        let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1333
1334                        let value_expr = if let Some(format_with) = &cfg.format_with {
1335                            // Parse the function path - if it fails, generate code that will produce a compile error
1336                            match syn::parse_str::<syn::Path>(format_with) {
1337                                Ok(func_path) => quote! { #func_path(&self.#field_name) },
1338                                Err(_) => {
1339                                    // Generate a compile error by using an invalid identifier
1340                                    let error_msg = format!(
1341                                        "Invalid function path in format_with: '{}'",
1342                                        format_with
1343                                    );
1344                                    quote! {
1345                                        compile_error!(#error_msg);
1346                                        String::new()
1347                                    }
1348                                }
1349                            }
1350                        } else {
1351                            quote! { self.#field_name.to_prompt() }
1352                        };
1353
1354                        text_field_parts.push(quote! {
1355                            text_parts.push(format!("{}: {}", #key, #value_expr));
1356                        });
1357                    }
1358                } else {
1359                    // Default handling for fields without specific config
1360                    text_field_parts.push(quote! {
1361                        text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1362                    });
1363                }
1364            }
1365
1366            match_arms.push(quote! {
1367                #target_name => {
1368                    let mut parts = Vec::new();
1369
1370                    #(#image_field_parts)*
1371
1372                    let mut text_parts = Vec::new();
1373                    #(#text_field_parts)*
1374
1375                    if !text_parts.is_empty() {
1376                        parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1377                    }
1378
1379                    Ok(parts)
1380                }
1381            });
1382        }
1383    }
1384
1385    // Collect all target names for error reporting
1386    let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1387
1388    // Add default case for unknown targets
1389    match_arms.push(quote! {
1390        _ => {
1391            let available = vec![#(#target_names.to_string()),*];
1392            Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
1393                target: target.to_string(),
1394                available,
1395            })
1396        }
1397    });
1398
1399    let struct_name = &input.ident;
1400    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1401
1402    let expanded = quote! {
1403        impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1404            fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
1405                match target {
1406                    #(#match_arms)*
1407                }
1408            }
1409        }
1410    };
1411
1412    TokenStream::from(expanded)
1413}
1414
1415/// Wrapper struct for parsing a comma-separated list of types
1416struct TypeList {
1417    types: Punctuated<syn::Type, Token![,]>,
1418}
1419
1420impl Parse for TypeList {
1421    fn parse(input: ParseStream) -> syn::Result<Self> {
1422        Ok(TypeList {
1423            types: Punctuated::parse_terminated(input)?,
1424        })
1425    }
1426}
1427
1428/// Generates a formatted Markdown examples section for the provided types.
1429///
1430/// This macro accepts a comma-separated list of types and generates a single
1431/// formatted Markdown string containing examples of each type.
1432///
1433/// # Example
1434///
1435/// ```rust,ignore
1436/// let examples = examples_section!(User, Concept);
1437/// // Produces a string like:
1438/// // ---
1439/// // ### Examples
1440/// //
1441/// // Here are examples of the data structures you should use.
1442/// //
1443/// // ---
1444/// // #### `User`
1445/// // {...json...}
1446/// // ---
1447/// // #### `Concept`
1448/// // {...json...}
1449/// // ---
1450/// ```
1451#[proc_macro]
1452pub fn examples_section(input: TokenStream) -> TokenStream {
1453    let input = parse_macro_input!(input as TypeList);
1454
1455    // Generate code for each type
1456    let mut type_sections = Vec::new();
1457
1458    for ty in input.types.iter() {
1459        // Extract the type name as a string
1460        let type_name_str = quote!(#ty).to_string();
1461
1462        // Generate the section for this type
1463        type_sections.push(quote! {
1464            {
1465                let type_name = #type_name_str;
1466                let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1467                format!("---\n#### `{}`\n{}", type_name, json_example)
1468            }
1469        });
1470    }
1471
1472    // Build the complete examples string
1473    let expanded = quote! {
1474        {
1475            let mut sections = Vec::new();
1476            sections.push("---".to_string());
1477            sections.push("### Examples".to_string());
1478            sections.push("".to_string());
1479            sections.push("Here are examples of the data structures you should use.".to_string());
1480            sections.push("".to_string());
1481
1482            #(sections.push(#type_sections);)*
1483
1484            sections.push("---".to_string());
1485
1486            sections.join("\n")
1487        }
1488    };
1489
1490    TokenStream::from(expanded)
1491}
1492
1493/// Helper function to parse struct-level #[prompt_for(target = "...", template = "...")] attribute
1494fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1495    for attr in attrs {
1496        if attr.path().is_ident("prompt_for")
1497            && let Ok(meta_list) = attr.meta.require_list()
1498            && let Ok(metas) =
1499                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1500        {
1501            let mut target_type = None;
1502            let mut template = None;
1503
1504            for meta in metas {
1505                match meta {
1506                    Meta::NameValue(nv) if nv.path.is_ident("target") => {
1507                        if let syn::Expr::Lit(syn::ExprLit {
1508                            lit: syn::Lit::Str(lit_str),
1509                            ..
1510                        }) = nv.value
1511                        {
1512                            // Parse the type string into a syn::Type
1513                            target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1514                        }
1515                    }
1516                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
1517                        if let syn::Expr::Lit(syn::ExprLit {
1518                            lit: syn::Lit::Str(lit_str),
1519                            ..
1520                        }) = nv.value
1521                        {
1522                            template = Some(lit_str.value());
1523                        }
1524                    }
1525                    _ => {}
1526                }
1527            }
1528
1529            if let (Some(target), Some(tmpl)) = (target_type, template) {
1530                return (target, tmpl);
1531            }
1532        }
1533    }
1534
1535    panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1536}
1537
1538/// A procedural attribute macro that generates prompt-building functions and extractor structs for intent enums.
1539///
1540/// This macro should be applied to an enum to generate:
1541/// 1. A prompt-building function that incorporates enum documentation
1542/// 2. An extractor struct that implements `IntentExtractor`
1543///
1544/// # Requirements
1545///
1546/// The enum must have an `#[intent(...)]` attribute with:
1547/// - `prompt`: The prompt template (supports Jinja-style variables)
1548/// - `extractor_tag`: The tag to use for extraction
1549///
1550/// # Example
1551///
1552/// ```rust,ignore
1553/// #[define_intent]
1554/// #[intent(
1555///     prompt = "Analyze the intent: {{ user_input }}",
1556///     extractor_tag = "intent"
1557/// )]
1558/// enum MyIntent {
1559///     /// Create a new item
1560///     Create,
1561///     /// Update an existing item
1562///     Update,
1563///     /// Delete an item
1564///     Delete,
1565/// }
1566/// ```
1567///
1568/// This will generate:
1569/// - `pub fn build_my_intent_prompt(user_input: &str) -> String`
1570/// - `pub struct MyIntentExtractor;` with `IntentExtractor<MyIntent>` implementation
1571#[proc_macro_attribute]
1572pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1573    let input = parse_macro_input!(item as DeriveInput);
1574
1575    // Verify this is an enum
1576    let enum_data = match &input.data {
1577        Data::Enum(data) => data,
1578        _ => {
1579            return syn::Error::new(
1580                input.ident.span(),
1581                "`#[define_intent]` can only be applied to enums",
1582            )
1583            .to_compile_error()
1584            .into();
1585        }
1586    };
1587
1588    // Parse the #[intent(...)] attribute
1589    let mut prompt_template = None;
1590    let mut extractor_tag = None;
1591    let mut mode = None;
1592
1593    for attr in &input.attrs {
1594        if attr.path().is_ident("intent")
1595            && let Ok(metas) =
1596                attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1597        {
1598            for meta in metas {
1599                match meta {
1600                    Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1601                        if let syn::Expr::Lit(syn::ExprLit {
1602                            lit: syn::Lit::Str(lit_str),
1603                            ..
1604                        }) = nv.value
1605                        {
1606                            prompt_template = Some(lit_str.value());
1607                        }
1608                    }
1609                    Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1610                        if let syn::Expr::Lit(syn::ExprLit {
1611                            lit: syn::Lit::Str(lit_str),
1612                            ..
1613                        }) = nv.value
1614                        {
1615                            extractor_tag = Some(lit_str.value());
1616                        }
1617                    }
1618                    Meta::NameValue(nv) if nv.path.is_ident("mode") => {
1619                        if let syn::Expr::Lit(syn::ExprLit {
1620                            lit: syn::Lit::Str(lit_str),
1621                            ..
1622                        }) = nv.value
1623                        {
1624                            mode = Some(lit_str.value());
1625                        }
1626                    }
1627                    _ => {}
1628                }
1629            }
1630        }
1631    }
1632
1633    // Parse the mode parameter (default to "single")
1634    let mode = mode.unwrap_or_else(|| "single".to_string());
1635
1636    // Validate mode
1637    if mode != "single" && mode != "multi_tag" {
1638        return syn::Error::new(
1639            input.ident.span(),
1640            "`mode` must be either \"single\" or \"multi_tag\"",
1641        )
1642        .to_compile_error()
1643        .into();
1644    }
1645
1646    // Validate required attributes
1647    let prompt_template = match prompt_template {
1648        Some(p) => p,
1649        None => {
1650            return syn::Error::new(
1651                input.ident.span(),
1652                "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1653            )
1654            .to_compile_error()
1655            .into();
1656        }
1657    };
1658
1659    // Handle multi_tag mode
1660    if mode == "multi_tag" {
1661        let enum_name = &input.ident;
1662        let actions_doc = generate_multi_tag_actions_doc(&enum_data.variants);
1663        return generate_multi_tag_output(
1664            &input,
1665            enum_name,
1666            enum_data,
1667            prompt_template,
1668            actions_doc,
1669        );
1670    }
1671
1672    // Continue with single mode logic
1673    let extractor_tag = match extractor_tag {
1674        Some(t) => t,
1675        None => {
1676            return syn::Error::new(
1677                input.ident.span(),
1678                "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1679            )
1680            .to_compile_error()
1681            .into();
1682        }
1683    };
1684
1685    // Generate the intents documentation
1686    let enum_name = &input.ident;
1687    let enum_docs = extract_doc_comments(&input.attrs);
1688
1689    let mut intents_doc_lines = Vec::new();
1690
1691    // Add enum description if present
1692    if !enum_docs.is_empty() {
1693        intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1694    } else {
1695        intents_doc_lines.push(format!("{}:", enum_name));
1696    }
1697    intents_doc_lines.push(String::new()); // Empty line
1698    intents_doc_lines.push("Possible values:".to_string());
1699
1700    // Add each variant with its documentation
1701    for variant in &enum_data.variants {
1702        let variant_name = &variant.ident;
1703        let variant_docs = extract_doc_comments(&variant.attrs);
1704
1705        if !variant_docs.is_empty() {
1706            intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1707        } else {
1708            intents_doc_lines.push(format!("- {}", variant_name));
1709        }
1710    }
1711
1712    let intents_doc_str = intents_doc_lines.join("\n");
1713
1714    // Parse template variables (excluding intents_doc which we'll inject)
1715    let placeholders = parse_template_placeholders_with_mode(&prompt_template);
1716    let user_variables: Vec<String> = placeholders
1717        .iter()
1718        .filter_map(|(name, _)| {
1719            if name != "intents_doc" {
1720                Some(name.clone())
1721            } else {
1722                None
1723            }
1724        })
1725        .collect();
1726
1727    // Generate function name (snake_case)
1728    let enum_name_str = enum_name.to_string();
1729    let snake_case_name = to_snake_case(&enum_name_str);
1730    let function_name = syn::Ident::new(
1731        &format!("build_{}_prompt", snake_case_name),
1732        proc_macro2::Span::call_site(),
1733    );
1734
1735    // Generate function parameters (all &str for simplicity)
1736    let function_params: Vec<proc_macro2::TokenStream> = user_variables
1737        .iter()
1738        .map(|var| {
1739            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1740            quote! { #ident: &str }
1741        })
1742        .collect();
1743
1744    // Generate context insertions
1745    let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1746        .iter()
1747        .map(|var| {
1748            let var_str = var.clone();
1749            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1750            quote! {
1751                __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1752            }
1753        })
1754        .collect();
1755
1756    // Template is already in Jinja syntax, no conversion needed
1757    let converted_template = prompt_template.clone();
1758
1759    // Generate extractor struct name
1760    let extractor_name = syn::Ident::new(
1761        &format!("{}Extractor", enum_name),
1762        proc_macro2::Span::call_site(),
1763    );
1764
1765    // Filter out the #[intent(...)] attribute from the enum attributes
1766    let filtered_attrs: Vec<_> = input
1767        .attrs
1768        .iter()
1769        .filter(|attr| !attr.path().is_ident("intent"))
1770        .collect();
1771
1772    // Rebuild the enum with filtered attributes
1773    let vis = &input.vis;
1774    let generics = &input.generics;
1775    let variants = &enum_data.variants;
1776    let enum_output = quote! {
1777        #(#filtered_attrs)*
1778        #vis enum #enum_name #generics {
1779            #variants
1780        }
1781    };
1782
1783    // Generate the complete output
1784    let expanded = quote! {
1785        // Output the enum without the #[intent(...)] attribute
1786        #enum_output
1787
1788        // Generate the prompt-building function
1789        pub fn #function_name(#(#function_params),*) -> String {
1790            let mut env = minijinja::Environment::new();
1791            env.add_template("prompt", #converted_template)
1792                .expect("Failed to parse intent prompt template");
1793
1794            let tmpl = env.get_template("prompt").unwrap();
1795
1796            let mut __template_context = std::collections::HashMap::new();
1797
1798            // Add intents_doc
1799            __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
1800
1801            // Add user-provided variables
1802            #(#context_insertions)*
1803
1804            tmpl.render(&__template_context)
1805                .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
1806        }
1807
1808        // Generate the extractor struct
1809        pub struct #extractor_name;
1810
1811        impl #extractor_name {
1812            pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
1813        }
1814
1815        impl llm_toolkit::intent::IntentExtractor<#enum_name> for #extractor_name {
1816            fn extract_intent(&self, response: &str) -> Result<#enum_name, llm_toolkit::intent::IntentExtractionError> {
1817                // Use the common extraction function with our tag
1818                llm_toolkit::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
1819            }
1820        }
1821    };
1822
1823    TokenStream::from(expanded)
1824}
1825
/// Convert PascalCase to snake_case.
///
/// An underscore is inserted before an uppercase character only when the previous
/// character was not uppercase, so runs of capitals stay together:
/// `MyIntent` → `my_intent`, `HTTPIntent` → `httpintent`.
/// Each uppercase character is replaced by its full lowercase mapping, which can be
/// more than one `char` for some non-ASCII letters.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them instead of only the first.
            result.extend(ch.to_lowercase());
            prev_upper = true;
        } else {
            result.push(ch);
            prev_upper = false;
        }
    }

    result
}
1846
1847/// Parse #[action(...)] attributes for enum variants
1848#[derive(Debug, Default)]
1849struct ActionAttrs {
1850    tag: Option<String>,
1851}
1852
1853fn parse_action_attrs(attrs: &[syn::Attribute]) -> ActionAttrs {
1854    let mut result = ActionAttrs::default();
1855
1856    for attr in attrs {
1857        if attr.path().is_ident("action")
1858            && let Ok(meta_list) = attr.meta.require_list()
1859            && let Ok(metas) =
1860                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1861        {
1862            for meta in metas {
1863                if let Meta::NameValue(nv) = meta
1864                    && nv.path.is_ident("tag")
1865                    && let syn::Expr::Lit(syn::ExprLit {
1866                        lit: syn::Lit::Str(lit_str),
1867                        ..
1868                    }) = nv.value
1869                {
1870                    result.tag = Some(lit_str.value());
1871                }
1872            }
1873        }
1874    }
1875
1876    result
1877}
1878
1879/// Parse #[action(...)] attributes for struct fields in variants
1880#[derive(Debug, Default)]
1881struct FieldActionAttrs {
1882    is_attribute: bool,
1883    is_inner_text: bool,
1884}
1885
1886fn parse_field_action_attrs(attrs: &[syn::Attribute]) -> FieldActionAttrs {
1887    let mut result = FieldActionAttrs::default();
1888
1889    for attr in attrs {
1890        if attr.path().is_ident("action")
1891            && let Ok(meta_list) = attr.meta.require_list()
1892        {
1893            let tokens_str = meta_list.tokens.to_string();
1894            if tokens_str == "attribute" {
1895                result.is_attribute = true;
1896            } else if tokens_str == "inner_text" {
1897                result.is_inner_text = true;
1898            }
1899        }
1900    }
1901
1902    result
1903}
1904
1905/// Generate actions_doc for multi_tag mode
1906fn generate_multi_tag_actions_doc(
1907    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1908) -> String {
1909    let mut doc_lines = Vec::new();
1910
1911    for variant in variants {
1912        let action_attrs = parse_action_attrs(&variant.attrs);
1913
1914        if let Some(tag) = action_attrs.tag {
1915            let variant_docs = extract_doc_comments(&variant.attrs);
1916
1917            match &variant.fields {
1918                syn::Fields::Unit => {
1919                    // Simple tag without parameters
1920                    doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
1921                }
1922                syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
1923                    // Tuple variant with inner text
1924                    doc_lines.push(format!("- `<{}>...</{}>`: {}", tag, tag, variant_docs));
1925                }
1926                syn::Fields::Named(fields) => {
1927                    // Struct variant with attributes and/or inner text
1928                    let mut attrs_str = Vec::new();
1929                    let mut has_inner_text = false;
1930
1931                    for field in &fields.named {
1932                        let field_name = field.ident.as_ref().unwrap();
1933                        let field_attrs = parse_field_action_attrs(&field.attrs);
1934
1935                        if field_attrs.is_attribute {
1936                            attrs_str.push(format!("{}=\"...\"", field_name));
1937                        } else if field_attrs.is_inner_text {
1938                            has_inner_text = true;
1939                        }
1940                    }
1941
1942                    let attrs_part = if !attrs_str.is_empty() {
1943                        format!(" {}", attrs_str.join(" "))
1944                    } else {
1945                        String::new()
1946                    };
1947
1948                    if has_inner_text {
1949                        doc_lines.push(format!(
1950                            "- `<{}{}>...</{}>`: {}",
1951                            tag, attrs_part, tag, variant_docs
1952                        ));
1953                    } else if !attrs_str.is_empty() {
1954                        doc_lines.push(format!("- `<{}{} />`: {}", tag, attrs_part, variant_docs));
1955                    } else {
1956                        doc_lines.push(format!("- `<{} />`: {}", tag, variant_docs));
1957                    }
1958
1959                    // Add field documentation
1960                    for field in &fields.named {
1961                        let field_name = field.ident.as_ref().unwrap();
1962                        let field_attrs = parse_field_action_attrs(&field.attrs);
1963                        let field_docs = extract_doc_comments(&field.attrs);
1964
1965                        if field_attrs.is_attribute {
1966                            doc_lines
1967                                .push(format!("  - `{}` (attribute): {}", field_name, field_docs));
1968                        } else if field_attrs.is_inner_text {
1969                            doc_lines
1970                                .push(format!("  - `{}` (inner_text): {}", field_name, field_docs));
1971                        }
1972                    }
1973                }
1974                _ => {
1975                    // Other field types not supported
1976                }
1977            }
1978        }
1979    }
1980
1981    doc_lines.join("\n")
1982}
1983
1984/// Generate regex for matching any of the defined action tags
1985fn generate_tags_regex(
1986    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1987) -> String {
1988    let mut tag_names = Vec::new();
1989
1990    for variant in variants {
1991        let action_attrs = parse_action_attrs(&variant.attrs);
1992        if let Some(tag) = action_attrs.tag {
1993            tag_names.push(tag);
1994        }
1995    }
1996
1997    if tag_names.is_empty() {
1998        return String::new();
1999    }
2000
2001    let tags_pattern = tag_names.join("|");
2002    // Match both self-closing tags like <Tag /> and content-based tags like <Tag>...</Tag>
2003    // (?is) enables case-insensitive and single-line mode where . matches newlines
2004    format!(
2005        r"(?is)<(?:{})\b[^>]*/>|<(?:{})\b[^>]*>.*?</(?:{})>",
2006        tags_pattern, tags_pattern, tags_pattern
2007    )
2008}
2009
2010/// Generate output for multi_tag mode
2011fn generate_multi_tag_output(
2012    input: &DeriveInput,
2013    enum_name: &syn::Ident,
2014    enum_data: &syn::DataEnum,
2015    prompt_template: String,
2016    actions_doc: String,
2017) -> TokenStream {
2018    // Parse template placeholders
2019    let placeholders = parse_template_placeholders_with_mode(&prompt_template);
2020    let user_variables: Vec<String> = placeholders
2021        .iter()
2022        .filter_map(|(name, _)| {
2023            if name != "actions_doc" {
2024                Some(name.clone())
2025            } else {
2026                None
2027            }
2028        })
2029        .collect();
2030
2031    // Generate function name (snake_case)
2032    let enum_name_str = enum_name.to_string();
2033    let snake_case_name = to_snake_case(&enum_name_str);
2034    let function_name = syn::Ident::new(
2035        &format!("build_{}_prompt", snake_case_name),
2036        proc_macro2::Span::call_site(),
2037    );
2038
2039    // Generate function parameters (all &str for simplicity)
2040    let function_params: Vec<proc_macro2::TokenStream> = user_variables
2041        .iter()
2042        .map(|var| {
2043            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2044            quote! { #ident: &str }
2045        })
2046        .collect();
2047
2048    // Generate context insertions
2049    let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
2050        .iter()
2051        .map(|var| {
2052            let var_str = var.clone();
2053            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
2054            quote! {
2055                __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
2056            }
2057        })
2058        .collect();
2059
2060    // Generate extractor struct name
2061    let extractor_name = syn::Ident::new(
2062        &format!("{}Extractor", enum_name),
2063        proc_macro2::Span::call_site(),
2064    );
2065
2066    // Filter out the #[intent(...)] and #[action(...)] attributes
2067    let filtered_attrs: Vec<_> = input
2068        .attrs
2069        .iter()
2070        .filter(|attr| !attr.path().is_ident("intent"))
2071        .collect();
2072
2073    // Filter action attributes from variants
2074    let filtered_variants: Vec<proc_macro2::TokenStream> = enum_data
2075        .variants
2076        .iter()
2077        .map(|variant| {
2078            let variant_name = &variant.ident;
2079            let variant_attrs: Vec<_> = variant
2080                .attrs
2081                .iter()
2082                .filter(|attr| !attr.path().is_ident("action"))
2083                .collect();
2084            let fields = &variant.fields;
2085
2086            // Filter field attributes
2087            let filtered_fields = match fields {
2088                syn::Fields::Named(named_fields) => {
2089                    let filtered: Vec<_> = named_fields
2090                        .named
2091                        .iter()
2092                        .map(|field| {
2093                            let field_name = &field.ident;
2094                            let field_type = &field.ty;
2095                            let field_vis = &field.vis;
2096                            let filtered_attrs: Vec<_> = field
2097                                .attrs
2098                                .iter()
2099                                .filter(|attr| !attr.path().is_ident("action"))
2100                                .collect();
2101                            quote! {
2102                                #(#filtered_attrs)*
2103                                #field_vis #field_name: #field_type
2104                            }
2105                        })
2106                        .collect();
2107                    quote! { { #(#filtered,)* } }
2108                }
2109                syn::Fields::Unnamed(unnamed_fields) => {
2110                    let types: Vec<_> = unnamed_fields
2111                        .unnamed
2112                        .iter()
2113                        .map(|field| {
2114                            let field_type = &field.ty;
2115                            quote! { #field_type }
2116                        })
2117                        .collect();
2118                    quote! { (#(#types),*) }
2119                }
2120                syn::Fields::Unit => quote! {},
2121            };
2122
2123            quote! {
2124                #(#variant_attrs)*
2125                #variant_name #filtered_fields
2126            }
2127        })
2128        .collect();
2129
2130    let vis = &input.vis;
2131    let generics = &input.generics;
2132
2133    // Generate XML parsing logic for extract_actions
2134    let parsing_arms = generate_parsing_arms(&enum_data.variants, enum_name);
2135
2136    // Generate the regex pattern for matching tags
2137    let tags_regex = generate_tags_regex(&enum_data.variants);
2138
2139    let expanded = quote! {
2140        // Output the enum without the #[intent(...)] and #[action(...)] attributes
2141        #(#filtered_attrs)*
2142        #vis enum #enum_name #generics {
2143            #(#filtered_variants),*
2144        }
2145
2146        // Generate the prompt-building function
2147        pub fn #function_name(#(#function_params),*) -> String {
2148            let mut env = minijinja::Environment::new();
2149            env.add_template("prompt", #prompt_template)
2150                .expect("Failed to parse intent prompt template");
2151
2152            let tmpl = env.get_template("prompt").unwrap();
2153
2154            let mut __template_context = std::collections::HashMap::new();
2155
2156            // Add actions_doc
2157            __template_context.insert("actions_doc".to_string(), minijinja::Value::from(#actions_doc));
2158
2159            // Add user-provided variables
2160            #(#context_insertions)*
2161
2162            tmpl.render(&__template_context)
2163                .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
2164        }
2165
2166        // Generate the extractor struct
2167        pub struct #extractor_name;
2168
2169        impl #extractor_name {
2170            fn parse_single_action(&self, text: &str) -> Option<#enum_name> {
2171                use ::quick_xml::events::Event;
2172                use ::quick_xml::Reader;
2173
2174                let mut actions = Vec::new();
2175                let mut reader = Reader::from_str(text);
2176                reader.config_mut().trim_text(true);
2177
2178                let mut buf = Vec::new();
2179
2180                loop {
2181                    match reader.read_event_into(&mut buf) {
2182                        Ok(Event::Start(e)) => {
2183                            let owned_e = e.into_owned();
2184                            let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2185                            let is_empty = false;
2186
2187                            #parsing_arms
2188                        }
2189                        Ok(Event::Empty(e)) => {
2190                            let owned_e = e.into_owned();
2191                            let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2192                            let is_empty = true;
2193
2194                            #parsing_arms
2195                        }
2196                        Ok(Event::Eof) => break,
2197                        Err(_) => {
2198                            // Silently ignore XML parsing errors
2199                            break;
2200                        }
2201                        _ => {}
2202                    }
2203                    buf.clear();
2204                }
2205
2206                actions.into_iter().next()
2207            }
2208
2209            pub fn extract_actions(&self, text: &str) -> Result<Vec<#enum_name>, llm_toolkit::intent::IntentError> {
2210                use ::quick_xml::events::Event;
2211                use ::quick_xml::Reader;
2212
2213                let mut actions = Vec::new();
2214                let mut reader = Reader::from_str(text);
2215                reader.config_mut().trim_text(true);
2216
2217                let mut buf = Vec::new();
2218
2219                loop {
2220                    match reader.read_event_into(&mut buf) {
2221                        Ok(Event::Start(e)) => {
2222                            let owned_e = e.into_owned();
2223                            let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2224                            let is_empty = false;
2225
2226                            #parsing_arms
2227                        }
2228                        Ok(Event::Empty(e)) => {
2229                            let owned_e = e.into_owned();
2230                            let tag_name = String::from_utf8_lossy(owned_e.name().as_ref()).to_string();
2231                            let is_empty = true;
2232
2233                            #parsing_arms
2234                        }
2235                        Ok(Event::Eof) => break,
2236                        Err(_) => {
2237                            // Silently ignore XML parsing errors
2238                            break;
2239                        }
2240                        _ => {}
2241                    }
2242                    buf.clear();
2243                }
2244
2245                Ok(actions)
2246            }
2247
2248            pub fn transform_actions<F>(&self, text: &str, mut transformer: F) -> String
2249            where
2250                F: FnMut(#enum_name) -> String,
2251            {
2252                use ::regex::Regex;
2253
2254                let regex_pattern = #tags_regex;
2255                if regex_pattern.is_empty() {
2256                    return text.to_string();
2257                }
2258
2259                let re = Regex::new(&regex_pattern).unwrap_or_else(|e| {
2260                    panic!("Failed to compile regex for action tags: {}", e);
2261                });
2262
2263                re.replace_all(text, |caps: &::regex::Captures| {
2264                    let matched = caps.get(0).map(|m| m.as_str()).unwrap_or("");
2265
2266                    // Try to parse the matched tag as an action
2267                    if let Some(action) = self.parse_single_action(matched) {
2268                        transformer(action)
2269                    } else {
2270                        // If parsing fails, return the original text
2271                        matched.to_string()
2272                    }
2273                }).to_string()
2274            }
2275
2276            pub fn strip_actions(&self, text: &str) -> String {
2277                self.transform_actions(text, |_| String::new())
2278            }
2279        }
2280    };
2281
2282    TokenStream::from(expanded)
2283}
2284
2285/// Generate parsing arms for XML extraction
2286fn generate_parsing_arms(
2287    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
2288    enum_name: &syn::Ident,
2289) -> proc_macro2::TokenStream {
2290    let mut arms = Vec::new();
2291
2292    for variant in variants {
2293        let variant_name = &variant.ident;
2294        let action_attrs = parse_action_attrs(&variant.attrs);
2295
2296        if let Some(tag) = action_attrs.tag {
2297            match &variant.fields {
2298                syn::Fields::Unit => {
2299                    // Simple tag without parameters
2300                    arms.push(quote! {
2301                        if &tag_name == #tag {
2302                            actions.push(#enum_name::#variant_name);
2303                        }
2304                    });
2305                }
2306                syn::Fields::Unnamed(_fields) => {
2307                    // Tuple variant with inner text - use reader.read_text()
2308                    arms.push(quote! {
2309                        if &tag_name == #tag && !is_empty {
2310                            // Use read_text to get inner text as owned String
2311                            match reader.read_text(owned_e.name()) {
2312                                Ok(text) => {
2313                                    actions.push(#enum_name::#variant_name(text.to_string()));
2314                                }
2315                                Err(_) => {
2316                                    // If reading text fails, push empty string
2317                                    actions.push(#enum_name::#variant_name(String::new()));
2318                                }
2319                            }
2320                        }
2321                    });
2322                }
2323                syn::Fields::Named(fields) => {
2324                    // Struct variant with attributes and/or inner text
2325                    let mut field_names = Vec::new();
2326                    let mut has_inner_text_field = None;
2327
2328                    for field in &fields.named {
2329                        let field_name = field.ident.as_ref().unwrap();
2330                        let field_attrs = parse_field_action_attrs(&field.attrs);
2331
2332                        if field_attrs.is_attribute {
2333                            field_names.push(field_name.clone());
2334                        } else if field_attrs.is_inner_text {
2335                            has_inner_text_field = Some(field_name.clone());
2336                        }
2337                    }
2338
2339                    if let Some(inner_text_field) = has_inner_text_field {
2340                        // Handle inner text
2341                        // Build attribute extraction code
2342                        let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2343                            quote! {
2344                                let mut #field_name = String::new();
2345                                for attr in owned_e.attributes() {
2346                                    if let Ok(attr) = attr {
2347                                        if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2348                                            #field_name = String::from_utf8_lossy(&attr.value).to_string();
2349                                            break;
2350                                        }
2351                                    }
2352                                }
2353                            }
2354                        }).collect();
2355
2356                        arms.push(quote! {
2357                            if &tag_name == #tag {
2358                                #(#attr_extractions)*
2359
2360                                // Check if it's a self-closing tag
2361                                if is_empty {
2362                                    let #inner_text_field = String::new();
2363                                    actions.push(#enum_name::#variant_name {
2364                                        #(#field_names,)*
2365                                        #inner_text_field,
2366                                    });
2367                                } else {
2368                                    // Use read_text to get inner text as owned String
2369                                    match reader.read_text(owned_e.name()) {
2370                                        Ok(text) => {
2371                                            let #inner_text_field = text.to_string();
2372                                            actions.push(#enum_name::#variant_name {
2373                                                #(#field_names,)*
2374                                                #inner_text_field,
2375                                            });
2376                                        }
2377                                        Err(_) => {
2378                                            // If reading text fails, push with empty string
2379                                            let #inner_text_field = String::new();
2380                                            actions.push(#enum_name::#variant_name {
2381                                                #(#field_names,)*
2382                                                #inner_text_field,
2383                                            });
2384                                        }
2385                                    }
2386                                }
2387                            }
2388                        });
2389                    } else {
2390                        // Only attributes
2391                        let attr_extractions: Vec<_> = field_names.iter().map(|field_name| {
2392                            quote! {
2393                                let mut #field_name = String::new();
2394                                for attr in owned_e.attributes() {
2395                                    if let Ok(attr) = attr {
2396                                        if attr.key.as_ref() == stringify!(#field_name).as_bytes() {
2397                                            #field_name = String::from_utf8_lossy(&attr.value).to_string();
2398                                            break;
2399                                        }
2400                                    }
2401                                }
2402                            }
2403                        }).collect();
2404
2405                        arms.push(quote! {
2406                            if &tag_name == #tag {
2407                                #(#attr_extractions)*
2408                                actions.push(#enum_name::#variant_name {
2409                                    #(#field_names),*
2410                                });
2411                            }
2412                        });
2413                    }
2414                }
2415            }
2416        }
2417    }
2418
2419    quote! {
2420        #(#arms)*
2421    }
2422}
2423
2424/// Derives the `ToPromptFor` trait for a struct
2425#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
2426pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
2427    let input = parse_macro_input!(input as DeriveInput);
2428
2429    // Parse the struct-level prompt_for attribute
2430    let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
2431
2432    let struct_name = &input.ident;
2433    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2434
2435    // Parse the template to find placeholders
2436    let placeholders = parse_template_placeholders_with_mode(&template);
2437
2438    // Convert template to minijinja syntax and build context generation code
2439    let mut converted_template = template.clone();
2440    let mut context_fields = Vec::new();
2441
2442    // Get struct fields for validation
2443    let fields = match &input.data {
2444        Data::Struct(data_struct) => match &data_struct.fields {
2445            syn::Fields::Named(fields) => &fields.named,
2446            _ => panic!("ToPromptFor is only supported for structs with named fields"),
2447        },
2448        _ => panic!("ToPromptFor is only supported for structs"),
2449    };
2450
2451    // Check if the struct has mode support (has #[prompt(mode = ...)] attribute)
2452    let has_mode_support = input.attrs.iter().any(|attr| {
2453        if attr.path().is_ident("prompt")
2454            && let Ok(metas) =
2455                attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
2456        {
2457            for meta in metas {
2458                if let Meta::NameValue(nv) = meta
2459                    && nv.path.is_ident("mode")
2460                {
2461                    return true;
2462                }
2463            }
2464        }
2465        false
2466    });
2467
2468    // Process each placeholder
2469    for (placeholder_name, mode_opt) in &placeholders {
2470        if placeholder_name == "self" {
2471            if let Some(specific_mode) = mode_opt {
2472                // {self:some_mode} - use a unique key
2473                let unique_key = format!("self__{}", specific_mode);
2474
2475                // Replace {{ self:mode }} with {{ self__mode }} in template
2476                let pattern = format!("{{{{ self:{} }}}}", specific_mode);
2477                let replacement = format!("{{{{ {} }}}}", unique_key);
2478                converted_template = converted_template.replace(&pattern, &replacement);
2479
2480                // Add to context with the specific mode
2481                context_fields.push(quote! {
2482                    context.insert(
2483                        #unique_key.to_string(),
2484                        minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
2485                    );
2486                });
2487            } else {
2488                // {{self}} - already in correct format, no replacement needed
2489
2490                if has_mode_support {
2491                    // If the struct has mode support, use to_prompt_with_mode with the mode parameter
2492                    context_fields.push(quote! {
2493                        context.insert(
2494                            "self".to_string(),
2495                            minijinja::Value::from(self.to_prompt_with_mode(mode))
2496                        );
2497                    });
2498                } else {
2499                    // If the struct doesn't have mode support, use to_prompt() which gives key-value format
2500                    context_fields.push(quote! {
2501                        context.insert(
2502                            "self".to_string(),
2503                            minijinja::Value::from(self.to_prompt())
2504                        );
2505                    });
2506                }
2507            }
2508        } else {
2509            // It's a field placeholder
2510            // Check if the field exists
2511            let field_exists = fields.iter().any(|f| {
2512                f.ident
2513                    .as_ref()
2514                    .is_some_and(|ident| ident == placeholder_name)
2515            });
2516
2517            if field_exists {
2518                let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
2519
2520                // {{field}} - already in correct format, no replacement needed
2521
2522                // Add field to context - serialize the field value
2523                context_fields.push(quote! {
2524                    context.insert(
2525                        #placeholder_name.to_string(),
2526                        minijinja::Value::from_serialize(&self.#field_ident)
2527                    );
2528                });
2529            }
2530            // If field doesn't exist, we'll let minijinja handle the error at runtime
2531        }
2532    }
2533
2534    let expanded = quote! {
2535        impl #impl_generics llm_toolkit::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
2536        where
2537            #target_type: serde::Serialize,
2538        {
2539            fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
2540                // Create minijinja environment and add template
2541                let mut env = minijinja::Environment::new();
2542                env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
2543                    panic!("Failed to parse template: {}", e)
2544                });
2545
2546                let tmpl = env.get_template("prompt").unwrap();
2547
2548                // Build context
2549                let mut context = std::collections::HashMap::new();
2550                // Add self to the context for field access in templates
2551                context.insert(
2552                    "self".to_string(),
2553                    minijinja::Value::from_serialize(self)
2554                );
2555                // Add target to the context
2556                context.insert(
2557                    "target".to_string(),
2558                    minijinja::Value::from_serialize(target)
2559                );
2560                #(#context_fields)*
2561
2562                // Render template
2563                tmpl.render(context).unwrap_or_else(|e| {
2564                    format!("Failed to render prompt: {}", e)
2565                })
2566            }
2567        }
2568    };
2569
2570    TokenStream::from(expanded)
2571}