llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Data, DeriveInput, Meta, Token,
5    parse::{Parse, ParseStream},
6    parse_macro_input,
7    punctuated::Punctuated,
8};
9
/// Rewrite legacy single-brace placeholders into minijinja's double-brace form.
///
/// Every `{` becomes `{{` and every `}` becomes `}}`. A brace that is already doubled is
/// emitted unchanged (its partner is consumed), so `{field}` turns into `{{field}}` while
/// `{{ field }}` passes through as-is.
fn convert_to_minijinja_syntax(template: &str) -> String {
    let mut out = String::with_capacity(template.len());
    let mut iter = template.chars().peekable();

    while let Some(c) = iter.next() {
        match c {
            '{' | '}' => {
                // Lone brace or already-doubled brace: both yield exactly two braces.
                // Swallow the partner if present so `{{` does not become `{{{{`.
                let _ = iter.next_if_eq(&c);
                out.push(c);
                out.push(c);
            }
            other => out.push(other),
        }
    }

    out
}
42
/// Parse template placeholders and extract field names with optional modes.
///
/// Two placeholder syntaxes are recognised:
/// - minijinja style: `{{ field }}` or `{{ field:mode }}`. Whitespace-control dashes, as in
///   `{{- field -}}`, are stripped.
/// - legacy single-brace style: `{field}` or `{field:mode}`.
///
/// Returns `(field_name, optional_mode)` pairs in order of appearance, with duplicates kept.
/// Everything before the first `:` is the field name and everything after it is the mode;
/// both are trimmed.
///
/// Minijinja statement tags (`{% ... %}`) and comments (`{# ... #}`) are skipped. They are not
/// variable references, and treating them as placeholders would surface bogus names in
/// template-validation warnings. Placeholders whose closing brace(s) never appear, and
/// placeholders with an empty name, are ignored.
fn parse_template_placeholders(template: &str) -> Vec<(String, Option<String>)> {
    /// Consume up to and including the next `}}`; returns the text before it, or `None`
    /// when the template ends first.
    fn read_until_double_close(
        chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
    ) -> Option<String> {
        let mut body = String::new();
        while let Some(ch) = chars.next() {
            if ch == '}' && chars.next_if_eq(&'}').is_some() {
                return Some(body);
            }
            body.push(ch);
        }
        None
    }

    /// Consume up to and including the next `}`; returns the text before it, or `None`
    /// when the template ends first.
    fn read_until_single_close(
        chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
    ) -> Option<String> {
        let mut body = String::new();
        for ch in chars.by_ref() {
            if ch == '}' {
                return Some(body);
            }
            body.push(ch);
        }
        None
    }

    /// Consume a `{% ... %}` / `{# ... #}` tag body up to and including `marker` followed by `}`.
    fn skip_tag(chars: &mut std::iter::Peekable<std::str::Chars<'_>>, marker: char) {
        while let Some(ch) = chars.next() {
            if ch == marker && chars.next_if_eq(&'}').is_some() {
                return;
            }
        }
    }

    /// Split `name:mode` into trimmed parts; `None` when the name is empty.
    fn split_name_and_mode(body: &str) -> Option<(String, Option<String>)> {
        let (name, mode) = match body.split_once(':') {
            Some((name, mode)) => (name.trim(), Some(mode.trim().to_string())),
            None => (body.trim(), None),
        };
        (!name.is_empty()).then(|| (name.to_string(), mode))
    }

    let mut placeholders = Vec::new();
    let mut chars = template.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            continue;
        }

        let body = match chars.peek().copied() {
            Some('{') => {
                chars.next(); // second opening brace
                read_until_double_close(&mut chars).map(|raw| {
                    // `{{-` / `-}}` are minijinja whitespace control, not part of the name.
                    let inner = raw.strip_prefix('-').unwrap_or(raw.as_str());
                    inner.strip_suffix('-').unwrap_or(inner).to_string()
                })
            }
            Some(marker @ ('%' | '#')) => {
                chars.next(); // the `%` or `#` right after `{`
                skip_tag(&mut chars, marker);
                None
            }
            _ => read_until_single_close(&mut chars),
        };

        if let Some(entry) = body.as_deref().and_then(split_name_and_mode) {
            placeholders.push(entry);
        }
    }

    placeholders
}
108
109/// Extract doc comments from attributes
110fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
111    attrs
112        .iter()
113        .filter_map(|attr| {
114            if attr.path().is_ident("doc")
115                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
116                && let syn::Expr::Lit(syn::ExprLit {
117                    lit: syn::Lit::Str(lit_str),
118                    ..
119                }) = &meta_name_value.value
120            {
121                return Some(lit_str.value());
122            }
123            None
124        })
125        .map(|s| s.trim().to_string())
126        .collect::<Vec<_>>()
127        .join(" ")
128}
129
130/// Generate example JSON representation for a struct
131fn generate_example_only_parts(
132    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
133    has_default: bool,
134) -> proc_macro2::TokenStream {
135    let mut field_values = Vec::new();
136
137    for field in fields.iter() {
138        let field_name = field.ident.as_ref().unwrap();
139        let field_name_str = field_name.to_string();
140        let attrs = parse_field_prompt_attrs(&field.attrs);
141
142        // Skip if marked to skip
143        if attrs.skip {
144            continue;
145        }
146
147        // Check if field has example attribute
148        if let Some(example) = attrs.example {
149            // Use the provided example value
150            field_values.push(quote! {
151                json_obj.insert(#field_name_str.to_string(), serde_json::Value::String(#example.to_string()));
152            });
153        } else if has_default {
154            // Use Default value if available
155            field_values.push(quote! {
156                let default_value = serde_json::to_value(&default_instance.#field_name)
157                    .unwrap_or(serde_json::Value::Null);
158                json_obj.insert(#field_name_str.to_string(), default_value);
159            });
160        } else {
161            // Use self's actual value
162            field_values.push(quote! {
163                let value = serde_json::to_value(&self.#field_name)
164                    .unwrap_or(serde_json::Value::Null);
165                json_obj.insert(#field_name_str.to_string(), value);
166            });
167        }
168    }
169
170    if has_default {
171        quote! {
172            {
173                let default_instance = Self::default();
174                let mut json_obj = serde_json::Map::new();
175                #(#field_values)*
176                let json_value = serde_json::Value::Object(json_obj);
177                let json_str = serde_json::to_string_pretty(&json_value)
178                    .unwrap_or_else(|_| "{}".to_string());
179                vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
180            }
181        }
182    } else {
183        quote! {
184            {
185                let mut json_obj = serde_json::Map::new();
186                #(#field_values)*
187                let json_value = serde_json::Value::Object(json_obj);
188                let json_str = serde_json::to_string_pretty(&json_value)
189                    .unwrap_or_else(|_| "{}".to_string());
190                vec![llm_toolkit::prompt::PromptPart::Text(json_str)]
191            }
192        }
193    }
194}
195
196/// Generate schema-only representation for a struct
197fn generate_schema_only_parts(
198    struct_name: &str,
199    struct_docs: &str,
200    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
201) -> proc_macro2::TokenStream {
202    let mut schema_lines = vec![];
203
204    // Add header
205    if !struct_docs.is_empty() {
206        schema_lines.push(format!("### Schema for `{}`\n{}", struct_name, struct_docs));
207    } else {
208        schema_lines.push(format!("### Schema for `{}`", struct_name));
209    }
210
211    schema_lines.push("{".to_string());
212
213    // Process fields
214    for (i, field) in fields.iter().enumerate() {
215        let field_name = field.ident.as_ref().unwrap();
216        let attrs = parse_field_prompt_attrs(&field.attrs);
217
218        // Skip if marked to skip
219        if attrs.skip {
220            continue;
221        }
222
223        // Get field documentation
224        let field_docs = extract_doc_comments(&field.attrs);
225
226        // Determine the type representation
227        let type_str = format_type_for_schema(&field.ty);
228
229        // Build field line
230        let mut field_line = format!("  \"{}\": \"{}\"", field_name, type_str);
231
232        // Add comment if there's documentation
233        if !field_docs.is_empty() {
234            field_line.push_str(&format!(", // {}", field_docs));
235        }
236
237        // Add comma if not last field (accounting for skipped fields)
238        let remaining_fields = fields
239            .iter()
240            .skip(i + 1)
241            .filter(|f| {
242                let attrs = parse_field_prompt_attrs(&f.attrs);
243                !attrs.skip
244            })
245            .count();
246
247        if remaining_fields > 0 {
248            field_line.push(',');
249        }
250
251        schema_lines.push(field_line);
252    }
253
254    schema_lines.push("}".to_string());
255
256    let schema_str = schema_lines.join("\n");
257
258    quote! {
259        vec![llm_toolkit::prompt::PromptPart::Text(#schema_str.to_string())]
260    }
261}
262
263/// Format a type for schema representation
264fn format_type_for_schema(ty: &syn::Type) -> String {
265    // Simple type formatting - can be enhanced
266    match ty {
267        syn::Type::Path(type_path) => {
268            let path = &type_path.path;
269            if let Some(last_segment) = path.segments.last() {
270                let type_name = last_segment.ident.to_string();
271
272                // Handle Option<T>
273                if type_name == "Option"
274                    && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
275                    && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
276                {
277                    return format!("{} | null", format_type_for_schema(inner_type));
278                }
279
280                // Map common types
281                match type_name.as_str() {
282                    "String" | "str" => "string".to_string(),
283                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
284                    | "u64" | "u128" | "usize" => "number".to_string(),
285                    "f32" | "f64" => "number".to_string(),
286                    "bool" => "boolean".to_string(),
287                    "Vec" => {
288                        if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
289                            && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
290                        {
291                            return format!("{}[]", format_type_for_schema(inner_type));
292                        }
293                        "array".to_string()
294                    }
295                    _ => type_name.to_lowercase(),
296                }
297            } else {
298                "unknown".to_string()
299            }
300        }
301        _ => "unknown".to_string(),
302    }
303}
304
/// Outcome of inspecting an enum variant's `#[prompt(...)]` attribute.
enum PromptAttribute {
    /// `#[prompt(skip)]`: the variant is left out of the generated prompt.
    Skip,
    /// `#[prompt("...")]`: the string literal is used in place of the variant's doc comment.
    Description(String),
    /// No usable `#[prompt(...)]` attribute was found.
    None,
}
311
312/// Parse #[prompt(...)] attribute on enum variant
313fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
314    for attr in attrs {
315        if attr.path().is_ident("prompt") {
316            // Check for #[prompt(skip)]
317            if let Ok(meta_list) = attr.meta.require_list() {
318                let tokens = &meta_list.tokens;
319                let tokens_str = tokens.to_string();
320                if tokens_str == "skip" {
321                    return PromptAttribute::Skip;
322                }
323            }
324
325            // Check for #[prompt("description")]
326            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
327                return PromptAttribute::Description(lit_str.value());
328            }
329        }
330    }
331    PromptAttribute::None
332}
333
/// Field-level settings collected from `#[prompt(...)]` attributes on a struct field.
///
/// The `Default` value (flags `false`, options `None`) means no recognised entries were found.
#[derive(Debug, Default)]
struct FieldPromptAttrs {
    /// Set by a bare `skip` entry; the field is left out of generated schemas and examples.
    skip: bool,
    /// Value of a `rename = "..."` entry.
    rename: Option<String>,
    /// Value of a `format_with = "..."` entry.
    format_with: Option<String>,
    /// Set by a bare `image` entry.
    image: bool,
    /// Value of an `example = "..."` entry; inserted as a JSON string in example output.
    example: Option<String>,
}
343
344/// Parse #[prompt(...)] attributes for struct fields
345fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
346    let mut result = FieldPromptAttrs::default();
347
348    for attr in attrs {
349        if attr.path().is_ident("prompt") {
350            // Try to parse as meta list #[prompt(key = value, ...)]
351            if let Ok(meta_list) = attr.meta.require_list() {
352                // Parse the tokens inside the parentheses
353                if let Ok(metas) =
354                    meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
355                {
356                    for meta in metas {
357                        match meta {
358                            Meta::Path(path) if path.is_ident("skip") => {
359                                result.skip = true;
360                            }
361                            Meta::NameValue(nv) if nv.path.is_ident("rename") => {
362                                if let syn::Expr::Lit(syn::ExprLit {
363                                    lit: syn::Lit::Str(lit_str),
364                                    ..
365                                }) = nv.value
366                                {
367                                    result.rename = Some(lit_str.value());
368                                }
369                            }
370                            Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
371                                if let syn::Expr::Lit(syn::ExprLit {
372                                    lit: syn::Lit::Str(lit_str),
373                                    ..
374                                }) = nv.value
375                                {
376                                    result.format_with = Some(lit_str.value());
377                                }
378                            }
379                            Meta::Path(path) if path.is_ident("image") => {
380                                result.image = true;
381                            }
382                            Meta::NameValue(nv) if nv.path.is_ident("example") => {
383                                if let syn::Expr::Lit(syn::ExprLit {
384                                    lit: syn::Lit::Str(lit_str),
385                                    ..
386                                }) = nv.value
387                                {
388                                    result.example = Some(lit_str.value());
389                                }
390                            }
391                            _ => {}
392                        }
393                    }
394                } else if meta_list.tokens.to_string() == "skip" {
395                    // Handle simple #[prompt(skip)] case
396                    result.skip = true;
397                } else if meta_list.tokens.to_string() == "image" {
398                    // Handle simple #[prompt(image)] case
399                    result.image = true;
400                }
401            }
402        }
403    }
404
405    result
406}
407
408/// Derives the `ToPrompt` trait for a struct or enum.
409///
410/// This macro provides two main functionalities depending on the type.
411///
412/// ## For Structs
413///
414/// It can generate a prompt based on a template string or by creating a key-value representation of the struct's fields.
415///
416/// ### Template-based Prompt
417///
418/// Use the `#[prompt(template = "...")]` attribute to provide a `minijinja` template. The struct fields will be available as variables in the template. The struct must also derive `serde::Serialize`.
419///
420/// ```rust,ignore
421/// #[derive(ToPrompt, Serialize)]
422/// #[prompt(template = "User {{ name }} is a {{ role }}.")]
423/// struct UserProfile {
424///     name: &'static str,
425///     role: &'static str,
426/// }
427/// ```
428///
429/// ### Tip: Handling Special Characters in Templates
430///
431/// When using raw string literals (e.g., `r#"..."#`) for your templates, be aware of a potential parsing issue if your template content includes the `#` character. To avoid this, use a different number of `#` symbols for the raw string delimiter.
432///
433/// **Problematic Example:**
434/// ```rust,ignore
435/// // This might fail to parse correctly
436/// #[prompt(template = r#"{"color": "#FFFFFF"}"#)]
437/// struct Color { /* ... */ }
438/// ```
439///
440/// **Solution:**
441/// ```rust,ignore
442/// // Use r##"..."## to avoid ambiguity with the inner '#'
443/// #[prompt(template = r##"{"color": "#FFFFFF"}"##)]
444/// struct Color { /* ... */ }
445/// ```
446///
447/// ## For Enums
448///
449/// For enums, the macro generates a descriptive prompt based on doc comments and attributes, outlining the available variants. See the documentation on the `ToPrompt` trait for more details.
450#[proc_macro_derive(ToPrompt, attributes(prompt))]
451pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
452    let input = parse_macro_input!(input as DeriveInput);
453
454    // Check if this is a struct or enum
455    match &input.data {
456        Data::Enum(data_enum) => {
457            // For enums, generate prompt from doc comments
458            let enum_name = &input.ident;
459            let enum_docs = extract_doc_comments(&input.attrs);
460
461            let mut prompt_lines = Vec::new();
462
463            // Add enum description
464            if !enum_docs.is_empty() {
465                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
466            } else {
467                prompt_lines.push(format!("{}:", enum_name));
468            }
469            prompt_lines.push(String::new()); // Empty line
470            prompt_lines.push("Possible values:".to_string());
471
472            // Add each variant with its documentation based on priority
473            for variant in &data_enum.variants {
474                let variant_name = &variant.ident;
475
476                // Apply fallback logic with priority
477                match parse_prompt_attribute(&variant.attrs) {
478                    PromptAttribute::Skip => {
479                        // Skip this variant completely
480                        continue;
481                    }
482                    PromptAttribute::Description(desc) => {
483                        // Use custom description from #[prompt("...")]
484                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
485                    }
486                    PromptAttribute::None => {
487                        // Fall back to doc comment or just variant name
488                        let variant_docs = extract_doc_comments(&variant.attrs);
489                        if !variant_docs.is_empty() {
490                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
491                        } else {
492                            prompt_lines.push(format!("- {}", variant_name));
493                        }
494                    }
495                }
496            }
497
498            let prompt_string = prompt_lines.join("\n");
499            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
500
501            let expanded = quote! {
502                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
503                    fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
504                        vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
505                    }
506
507                    fn to_prompt(&self) -> String {
508                        #prompt_string.to_string()
509                    }
510                }
511            };
512
513            TokenStream::from(expanded)
514        }
515        Data::Struct(data_struct) => {
516            // Parse struct-level prompt attributes for template, template_file, mode, and validate
517            let mut template_attr = None;
518            let mut template_file_attr = None;
519            let mut mode_attr = None;
520            let mut validate_attr = false;
521
522            for attr in &input.attrs {
523                if attr.path().is_ident("prompt") {
524                    // Try to parse the attribute arguments
525                    if let Ok(metas) =
526                        attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
527                    {
528                        for meta in metas {
529                            match meta {
530                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
531                                    if let syn::Expr::Lit(expr_lit) = nv.value
532                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
533                                    {
534                                        template_attr = Some(lit_str.value());
535                                    }
536                                }
537                                Meta::NameValue(nv) if nv.path.is_ident("template_file") => {
538                                    if let syn::Expr::Lit(expr_lit) = nv.value
539                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
540                                    {
541                                        template_file_attr = Some(lit_str.value());
542                                    }
543                                }
544                                Meta::NameValue(nv) if nv.path.is_ident("mode") => {
545                                    if let syn::Expr::Lit(expr_lit) = nv.value
546                                        && let syn::Lit::Str(lit_str) = expr_lit.lit
547                                    {
548                                        mode_attr = Some(lit_str.value());
549                                    }
550                                }
551                                Meta::NameValue(nv) if nv.path.is_ident("validate") => {
552                                    if let syn::Expr::Lit(expr_lit) = nv.value
553                                        && let syn::Lit::Bool(lit_bool) = expr_lit.lit
554                                    {
555                                        validate_attr = lit_bool.value();
556                                    }
557                                }
558                                _ => {}
559                            }
560                        }
561                    }
562                }
563            }
564
565            // Check for mutual exclusivity between template and template_file
566            if template_attr.is_some() && template_file_attr.is_some() {
567                return syn::Error::new(
568                    input.ident.span(),
569                    "The `template` and `template_file` attributes are mutually exclusive. Please use only one.",
570                ).to_compile_error().into();
571            }
572
573            // Load template from file if template_file is specified
574            let template_str = if let Some(file_path) = template_file_attr {
575                // Try multiple strategies to find the template file
576                // This is necessary to support both normal compilation and trybuild tests
577
578                let mut full_path = None;
579
580                // Strategy 1: Try relative to CARGO_MANIFEST_DIR (normal compilation)
581                if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
582                    // Check if this is a trybuild temporary directory
583                    let is_trybuild = manifest_dir.contains("target/tests/trybuild");
584
585                    if !is_trybuild {
586                        // Normal compilation - use CARGO_MANIFEST_DIR directly
587                        let candidate = std::path::Path::new(&manifest_dir).join(&file_path);
588                        if candidate.exists() {
589                            full_path = Some(candidate);
590                        }
591                    } else {
592                        // For trybuild, we need to find the original source directory
593                        // The manifest_dir looks like: .../target/tests/trybuild/llm-toolkit-macros
594                        // We need to get back to the original llm-toolkit-macros source directory
595
596                        // Extract the workspace root from the path
597                        if let Some(target_pos) = manifest_dir.find("/target/tests/trybuild") {
598                            let workspace_root = &manifest_dir[..target_pos];
599                            // Now construct the path to the original llm-toolkit-macros source
600                            let original_macros_dir = std::path::Path::new(workspace_root)
601                                .join("crates")
602                                .join("llm-toolkit-macros");
603
604                            let candidate = original_macros_dir.join(&file_path);
605                            if candidate.exists() {
606                                full_path = Some(candidate);
607                            }
608                        }
609                    }
610                }
611
612                // Strategy 2: Try as an absolute path or relative to current directory
613                if full_path.is_none() {
614                    let candidate = std::path::Path::new(&file_path).to_path_buf();
615                    if candidate.exists() {
616                        full_path = Some(candidate);
617                    }
618                }
619
620                // Strategy 3: For trybuild tests - try to find the file by looking in parent directories
621                // This handles the case where trybuild creates a temporary project
622                if full_path.is_none()
623                    && let Ok(current_dir) = std::env::current_dir()
624                {
625                    let mut search_dir = current_dir.as_path();
626                    // Search up to 10 levels up
627                    for _ in 0..10 {
628                        // Try from the llm-toolkit-macros directory
629                        let macros_dir = search_dir.join("crates/llm-toolkit-macros");
630                        if macros_dir.exists() {
631                            let candidate = macros_dir.join(&file_path);
632                            if candidate.exists() {
633                                full_path = Some(candidate);
634                                break;
635                            }
636                        }
637                        // Try directly
638                        let candidate = search_dir.join(&file_path);
639                        if candidate.exists() {
640                            full_path = Some(candidate);
641                            break;
642                        }
643                        if let Some(parent) = search_dir.parent() {
644                            search_dir = parent;
645                        } else {
646                            break;
647                        }
648                    }
649                }
650
651                // If we still haven't found the file, use the original path for a better error message
652                let final_path =
653                    full_path.unwrap_or_else(|| std::path::Path::new(&file_path).to_path_buf());
654
655                // Read the file at compile time
656                match std::fs::read_to_string(&final_path) {
657                    Ok(content) => Some(content),
658                    Err(e) => {
659                        return syn::Error::new(
660                            input.ident.span(),
661                            format!(
662                                "Failed to read template file '{}': {}",
663                                final_path.display(),
664                                e
665                            ),
666                        )
667                        .to_compile_error()
668                        .into();
669                    }
670                }
671            } else {
672                template_attr
673            };
674
675            // Perform validation if requested
676            if validate_attr && let Some(template) = &template_str {
677                // Validate Jinja syntax
678                let mut env = minijinja::Environment::new();
679                if let Err(e) = env.add_template("validation", template) {
680                    // Generate a compile warning using deprecated const hack
681                    let warning_msg =
682                        format!("Template validation warning: Invalid Jinja syntax - {}", e);
683                    let warning_ident = syn::Ident::new(
684                        "TEMPLATE_VALIDATION_WARNING",
685                        proc_macro2::Span::call_site(),
686                    );
687                    let _warning_tokens = quote! {
688                        #[deprecated(note = #warning_msg)]
689                        const #warning_ident: () = ();
690                        let _ = #warning_ident;
691                    };
692                    // We'll inject this warning into the generated code
693                    eprintln!("cargo:warning={}", warning_msg);
694                }
695
696                // Extract variables from template and check against struct fields
697                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
698                    &fields.named
699                } else {
700                    panic!("Template validation is only supported for structs with named fields.");
701                };
702
703                let field_names: std::collections::HashSet<String> = fields
704                    .iter()
705                    .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
706                    .collect();
707
708                // Parse template placeholders
709                let placeholders = parse_template_placeholders(template);
710
711                for (placeholder_name, _mode) in &placeholders {
712                    if placeholder_name != "self" && !field_names.contains(placeholder_name) {
713                        let warning_msg = format!(
714                            "Template validation warning: Variable '{}' used in template but not found in struct fields",
715                            placeholder_name
716                        );
717                        eprintln!("cargo:warning={}", warning_msg);
718                    }
719                }
720            }
721
722            let name = input.ident;
723            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
724
725            // Extract struct name and doc comment for use in schema generation
726            let struct_docs = extract_doc_comments(&input.attrs);
727
728            // Check if this is a mode-based struct (mode attribute present)
729            let is_mode_based =
730                mode_attr.is_some() || (template_str.is_none() && struct_docs.contains("mode"));
731
732            let expanded = if is_mode_based || mode_attr.is_some() {
733                // Mode-based generation: support schema_only, example_only, full
734                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
735                    &fields.named
736                } else {
737                    panic!(
738                        "Mode-based prompt generation is only supported for structs with named fields."
739                    );
740                };
741
742                let struct_name_str = name.to_string();
743
744                // Check if struct derives Default
745                let has_default = input.attrs.iter().any(|attr| {
746                    if attr.path().is_ident("derive") {
747                        if let Ok(meta_list) = attr.meta.require_list() {
748                            let tokens_str = meta_list.tokens.to_string();
749                            tokens_str.contains("Default")
750                        } else {
751                            false
752                        }
753                    } else {
754                        false
755                    }
756                });
757
758                // Generate schema-only parts
759                let schema_parts =
760                    generate_schema_only_parts(&struct_name_str, &struct_docs, fields);
761
762                // Generate example parts
763                let example_parts = generate_example_only_parts(fields, has_default);
764
765                quote! {
766                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
767                        fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<llm_toolkit::prompt::PromptPart> {
768                            match mode {
769                                "schema_only" => #schema_parts,
770                                "example_only" => #example_parts,
771                                "full" | _ => {
772                                    // Combine schema and example
773                                    let mut parts = Vec::new();
774
775                                    // Add schema
776                                    let schema_parts = #schema_parts;
777                                    parts.extend(schema_parts);
778
779                                    // Add separator and example header
780                                    parts.push(llm_toolkit::prompt::PromptPart::Text("\n### Example".to_string()));
781                                    parts.push(llm_toolkit::prompt::PromptPart::Text(
782                                        format!("Here is an example of a valid `{}` object:", #struct_name_str)
783                                    ));
784
785                                    // Add example
786                                    let example_parts = #example_parts;
787                                    parts.extend(example_parts);
788
789                                    parts
790                                }
791                            }
792                        }
793
794                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
795                            self.to_prompt_parts_with_mode("full")
796                        }
797
798                        fn to_prompt(&self) -> String {
799                            self.to_prompt_parts()
800                                .into_iter()
801                                .filter_map(|part| match part {
802                                    llm_toolkit::prompt::PromptPart::Text(text) => Some(text),
803                                    _ => None,
804                                })
805                                .collect::<Vec<_>>()
806                                .join("\n")
807                        }
808                    }
809                }
810            } else if let Some(template) = template_str {
811                // Use template-based approach if template is provided
812                // Collect image fields separately for to_prompt_parts()
813                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
814                    &fields.named
815                } else {
816                    panic!(
817                        "Template prompt generation is only supported for structs with named fields."
818                    );
819                };
820
821                // Parse template to detect mode syntax
822                let placeholders = parse_template_placeholders(&template);
823                // Only use custom mode processing if template actually contains :mode syntax
824                let has_mode_syntax = placeholders.iter().any(|(field_name, mode)| {
825                    mode.is_some()
826                        && fields
827                            .iter()
828                            .any(|f| f.ident.as_ref().unwrap() == field_name)
829                });
830
831                let mut image_field_parts = Vec::new();
832                for f in fields.iter() {
833                    let field_name = f.ident.as_ref().unwrap();
834                    let attrs = parse_field_prompt_attrs(&f.attrs);
835
836                    if attrs.image {
837                        // This field is marked as an image
838                        image_field_parts.push(quote! {
839                            parts.extend(self.#field_name.to_prompt_parts());
840                        });
841                    }
842                }
843
844                // Generate appropriate code based on whether mode syntax is used
845                if has_mode_syntax {
846                    // Build custom context for fields with mode specifications
847                    let mut context_fields = Vec::new();
848
849                    // Convert template to minijinja syntax, but preserve mode information
850                    // We'll replace {field:mode} with unique keys for each mode
851                    let mut converted_template = template.clone();
852
853                    // Process each placeholder
854                    for (field_name, mode_opt) in &placeholders {
855                        // Find the corresponding field
856                        let field_ident =
857                            syn::Ident::new(field_name, proc_macro2::Span::call_site());
858
859                        if let Some(mode) = mode_opt {
860                            // Create a unique key for this field:mode combination
861                            let unique_key = format!("{}__{}", field_name, mode);
862
863                            // Replace {field:mode} with {{field__mode}} in template
864                            let pattern = format!("{{{}:{}}}", field_name, mode);
865                            let replacement = format!("{{{{{}}}}}", unique_key);
866                            converted_template = converted_template.replace(&pattern, &replacement);
867
868                            // Field with mode specification
869                            context_fields.push(quote! {
870                                context.insert(
871                                    #unique_key.to_string(),
872                                    minijinja::Value::from(self.#field_ident.to_prompt_with_mode(#mode))
873                                );
874                            });
875                        } else {
876                            // Replace {field} with {{field}} in template
877                            let pattern = format!("{{{}}}", field_name);
878                            let replacement = format!("{{{{{}}}}}", field_name);
879                            converted_template = converted_template.replace(&pattern, &replacement);
880
881                            // Field without mode (use default)
882                            context_fields.push(quote! {
883                                context.insert(
884                                    #field_name.to_string(),
885                                    minijinja::Value::from(self.#field_ident.to_prompt())
886                                );
887                            });
888                        }
889                    }
890
891                    quote! {
892                        impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
893                            fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
894                                let mut parts = Vec::new();
895
896                                // Add image parts first
897                                #(#image_field_parts)*
898
899                                // Build custom context and render template
900                                let text = {
901                                    let mut env = minijinja::Environment::new();
902                                    env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
903                                        panic!("Failed to parse template: {}", e)
904                                    });
905
906                                    let tmpl = env.get_template("prompt").unwrap();
907
908                                    let mut context = std::collections::HashMap::new();
909                                    #(#context_fields)*
910
911                                    tmpl.render(context).unwrap_or_else(|e| {
912                                        format!("Failed to render prompt: {}", e)
913                                    })
914                                };
915
916                                if !text.is_empty() {
917                                    parts.push(llm_toolkit::prompt::PromptPart::Text(text));
918                                }
919
920                                parts
921                            }
922
923                            fn to_prompt(&self) -> String {
924                                // Same logic for to_prompt
925                                let mut env = minijinja::Environment::new();
926                                env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
927                                    panic!("Failed to parse template: {}", e)
928                                });
929
930                                let tmpl = env.get_template("prompt").unwrap();
931
932                                let mut context = std::collections::HashMap::new();
933                                #(#context_fields)*
934
935                                tmpl.render(context).unwrap_or_else(|e| {
936                                    format!("Failed to render prompt: {}", e)
937                                })
938                            }
939                        }
940                    }
941                } else {
942                    // No mode syntax, convert single braces to double braces for minijinja
943                    let converted_template = convert_to_minijinja_syntax(&template);
944
945                    quote! {
946                        impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
947                            fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
948                                let mut parts = Vec::new();
949
950                                // Add image parts first
951                                #(#image_field_parts)*
952
953                                // Add the rendered template as text
954                                let text = llm_toolkit::prompt::render_prompt(#converted_template, self).unwrap_or_else(|e| {
955                                    format!("Failed to render prompt: {}", e)
956                                });
957                                if !text.is_empty() {
958                                    parts.push(llm_toolkit::prompt::PromptPart::Text(text));
959                                }
960
961                                parts
962                            }
963
964                            fn to_prompt(&self) -> String {
965                                llm_toolkit::prompt::render_prompt(#converted_template, self).unwrap_or_else(|e| {
966                                    format!("Failed to render prompt: {}", e)
967                                })
968                            }
969                        }
970                    }
971                }
972            } else {
973                // Use default key-value format if no template is provided
974                // Now also generate to_prompt_parts() for multimodal support
975                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
976                    &fields.named
977                } else {
978                    panic!(
979                        "Default prompt generation is only supported for structs with named fields."
980                    );
981                };
982
983                // Separate image fields from text fields
984                let mut text_field_parts = Vec::new();
985                let mut image_field_parts = Vec::new();
986
987                for f in fields.iter() {
988                    let field_name = f.ident.as_ref().unwrap();
989                    let attrs = parse_field_prompt_attrs(&f.attrs);
990
991                    // Skip if #[prompt(skip)] is present
992                    if attrs.skip {
993                        continue;
994                    }
995
996                    if attrs.image {
997                        // This field is marked as an image
998                        image_field_parts.push(quote! {
999                            parts.extend(self.#field_name.to_prompt_parts());
1000                        });
1001                    } else {
1002                        // This is a regular text field
1003                        // Determine the key based on priority:
1004                        // 1. #[prompt(rename = "new_name")]
1005                        // 2. Doc comment
1006                        // 3. Field name (fallback)
1007                        let key = if let Some(rename) = attrs.rename {
1008                            rename
1009                        } else {
1010                            let doc_comment = extract_doc_comments(&f.attrs);
1011                            if !doc_comment.is_empty() {
1012                                doc_comment
1013                            } else {
1014                                field_name.to_string()
1015                            }
1016                        };
1017
1018                        // Determine the value based on format_with attribute
1019                        let value_expr = if let Some(format_with) = attrs.format_with {
1020                            // Parse the function path string into a syn::Path
1021                            let func_path: syn::Path =
1022                                syn::parse_str(&format_with).unwrap_or_else(|_| {
1023                                    panic!("Invalid function path: {}", format_with)
1024                                });
1025                            quote! { #func_path(&self.#field_name) }
1026                        } else {
1027                            quote! { self.#field_name.to_prompt() }
1028                        };
1029
1030                        text_field_parts.push(quote! {
1031                            text_parts.push(format!("{}: {}", #key, #value_expr));
1032                        });
1033                    }
1034                }
1035
1036                // Generate the implementation with to_prompt_parts()
1037                quote! {
1038                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
1039                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
1040                            let mut parts = Vec::new();
1041
1042                            // Add image parts first
1043                            #(#image_field_parts)*
1044
1045                            // Collect text parts and add as a single text prompt part
1046                            let mut text_parts = Vec::new();
1047                            #(#text_field_parts)*
1048
1049                            if !text_parts.is_empty() {
1050                                parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1051                            }
1052
1053                            parts
1054                        }
1055
1056                        fn to_prompt(&self) -> String {
1057                            let mut text_parts = Vec::new();
1058                            #(#text_field_parts)*
1059                            text_parts.join("\n")
1060                        }
1061                    }
1062                }
1063            };
1064
1065            TokenStream::from(expanded)
1066        }
1067        Data::Union(_) => {
1068            panic!("`#[derive(ToPrompt)]` is not supported for unions");
1069        }
1070    }
1071}
1072
/// Information about a prompt target collected while deriving `ToPromptSet`.
///
/// One `TargetInfo` produces one match arm in the generated `to_prompt_parts_for`.
#[derive(Debug, Clone)]
struct TargetInfo {
    /// Target name; becomes the string pattern matched against the `target` argument.
    name: String,
    /// Template from struct-level `#[prompt_for(template = "...")]`. When `None` (no template
    /// given, including targets named only on fields) key-value output is generated instead.
    template: Option<String>,
    /// Per-field configuration for this target, keyed by the field identifier as a string.
    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
}
1080
/// Configuration for how a field should be handled for a specific target.
///
/// Built from field-level `#[prompt_for(...)]` attributes; `Default` means "no special handling".
#[derive(Debug, Clone, Default)]
struct FieldTargetConfig {
    /// `skip`: omit the field from this target's key-value output.
    skip: bool,
    /// `rename = "..."`: key used instead of the field name in key-value output.
    rename: Option<String>,
    /// `format_with = "path"`: function called as `path(&self.field)` to format the value.
    format_with: Option<String>,
    /// `image`: emit the field's `to_prompt_parts()` as separate parts instead of a text line.
    image: bool,
    /// Set when the field named this target explicitly via `#[prompt_for(name = "...")]`.
    include_only: bool, // true if this field is specifically included for this target
}
1090
1091/// Parse #[prompt_for(...)] attributes for ToPromptSet
1092fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
1093    let mut configs = Vec::new();
1094
1095    for attr in attrs {
1096        if attr.path().is_ident("prompt_for")
1097            && let Ok(meta_list) = attr.meta.require_list()
1098        {
1099            // Try to parse as meta list
1100            if meta_list.tokens.to_string() == "skip" {
1101                // Simple #[prompt_for(skip)] applies to all targets
1102                let config = FieldTargetConfig {
1103                    skip: true,
1104                    ..Default::default()
1105                };
1106                configs.push(("*".to_string(), config));
1107            } else if let Ok(metas) =
1108                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1109            {
1110                let mut target_name = None;
1111                let mut config = FieldTargetConfig::default();
1112
1113                for meta in metas {
1114                    match meta {
1115                        Meta::NameValue(nv) if nv.path.is_ident("name") => {
1116                            if let syn::Expr::Lit(syn::ExprLit {
1117                                lit: syn::Lit::Str(lit_str),
1118                                ..
1119                            }) = nv.value
1120                            {
1121                                target_name = Some(lit_str.value());
1122                            }
1123                        }
1124                        Meta::Path(path) if path.is_ident("skip") => {
1125                            config.skip = true;
1126                        }
1127                        Meta::NameValue(nv) if nv.path.is_ident("rename") => {
1128                            if let syn::Expr::Lit(syn::ExprLit {
1129                                lit: syn::Lit::Str(lit_str),
1130                                ..
1131                            }) = nv.value
1132                            {
1133                                config.rename = Some(lit_str.value());
1134                            }
1135                        }
1136                        Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
1137                            if let syn::Expr::Lit(syn::ExprLit {
1138                                lit: syn::Lit::Str(lit_str),
1139                                ..
1140                            }) = nv.value
1141                            {
1142                                config.format_with = Some(lit_str.value());
1143                            }
1144                        }
1145                        Meta::Path(path) if path.is_ident("image") => {
1146                            config.image = true;
1147                        }
1148                        _ => {}
1149                    }
1150                }
1151
1152                if let Some(name) = target_name {
1153                    config.include_only = true;
1154                    configs.push((name, config));
1155                }
1156            }
1157        }
1158    }
1159
1160    configs
1161}
1162
1163/// Parse struct-level #[prompt_for(...)] attributes to find target templates
1164fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
1165    let mut targets = Vec::new();
1166
1167    for attr in attrs {
1168        if attr.path().is_ident("prompt_for")
1169            && let Ok(meta_list) = attr.meta.require_list()
1170            && let Ok(metas) =
1171                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1172        {
1173            let mut target_name = None;
1174            let mut template = None;
1175
1176            for meta in metas {
1177                match meta {
1178                    Meta::NameValue(nv) if nv.path.is_ident("name") => {
1179                        if let syn::Expr::Lit(syn::ExprLit {
1180                            lit: syn::Lit::Str(lit_str),
1181                            ..
1182                        }) = nv.value
1183                        {
1184                            target_name = Some(lit_str.value());
1185                        }
1186                    }
1187                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
1188                        if let syn::Expr::Lit(syn::ExprLit {
1189                            lit: syn::Lit::Str(lit_str),
1190                            ..
1191                        }) = nv.value
1192                        {
1193                            template = Some(lit_str.value());
1194                        }
1195                    }
1196                    _ => {}
1197                }
1198            }
1199
1200            if let Some(name) = target_name {
1201                targets.push(TargetInfo {
1202                    name,
1203                    template,
1204                    field_configs: std::collections::HashMap::new(),
1205                });
1206            }
1207        }
1208    }
1209
1210    targets
1211}
1212
1213#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
1214pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
1215    let input = parse_macro_input!(input as DeriveInput);
1216
1217    // Only support structs with named fields
1218    let data_struct = match &input.data {
1219        Data::Struct(data) => data,
1220        _ => {
1221            return syn::Error::new(
1222                input.ident.span(),
1223                "`#[derive(ToPromptSet)]` is only supported for structs",
1224            )
1225            .to_compile_error()
1226            .into();
1227        }
1228    };
1229
1230    let fields = match &data_struct.fields {
1231        syn::Fields::Named(fields) => &fields.named,
1232        _ => {
1233            return syn::Error::new(
1234                input.ident.span(),
1235                "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
1236            )
1237            .to_compile_error()
1238            .into();
1239        }
1240    };
1241
1242    // Parse struct-level attributes to find targets
1243    let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
1244
1245    // Parse field-level attributes
1246    for field in fields.iter() {
1247        let field_name = field.ident.as_ref().unwrap().to_string();
1248        let field_configs = parse_prompt_for_attrs(&field.attrs);
1249
1250        for (target_name, config) in field_configs {
1251            if target_name == "*" {
1252                // Apply to all targets
1253                for target in &mut targets {
1254                    target
1255                        .field_configs
1256                        .entry(field_name.clone())
1257                        .or_insert_with(FieldTargetConfig::default)
1258                        .skip = config.skip;
1259                }
1260            } else {
1261                // Find or create the target
1262                let target_exists = targets.iter().any(|t| t.name == target_name);
1263                if !target_exists {
1264                    // Add implicit target if not defined at struct level
1265                    targets.push(TargetInfo {
1266                        name: target_name.clone(),
1267                        template: None,
1268                        field_configs: std::collections::HashMap::new(),
1269                    });
1270                }
1271
1272                let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
1273
1274                target.field_configs.insert(field_name.clone(), config);
1275            }
1276        }
1277    }
1278
1279    // Generate match arms for each target
1280    let mut match_arms = Vec::new();
1281
1282    for target in &targets {
1283        let target_name = &target.name;
1284
1285        if let Some(template_str) = &target.template {
1286            // Template-based generation
1287            let mut image_parts = Vec::new();
1288
1289            for field in fields.iter() {
1290                let field_name = field.ident.as_ref().unwrap();
1291                let field_name_str = field_name.to_string();
1292
1293                if let Some(config) = target.field_configs.get(&field_name_str)
1294                    && config.image
1295                {
1296                    image_parts.push(quote! {
1297                        parts.extend(self.#field_name.to_prompt_parts());
1298                    });
1299                }
1300            }
1301
1302            match_arms.push(quote! {
1303                #target_name => {
1304                    let mut parts = Vec::new();
1305
1306                    #(#image_parts)*
1307
1308                    let text = llm_toolkit::prompt::render_prompt(#template_str, self)
1309                        .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
1310                            target: #target_name.to_string(),
1311                            source: e,
1312                        })?;
1313
1314                    if !text.is_empty() {
1315                        parts.push(llm_toolkit::prompt::PromptPart::Text(text));
1316                    }
1317
1318                    Ok(parts)
1319                }
1320            });
1321        } else {
1322            // Key-value based generation
1323            let mut text_field_parts = Vec::new();
1324            let mut image_field_parts = Vec::new();
1325
1326            for field in fields.iter() {
1327                let field_name = field.ident.as_ref().unwrap();
1328                let field_name_str = field_name.to_string();
1329
1330                // Check if field should be included for this target
1331                let config = target.field_configs.get(&field_name_str);
1332
1333                // Skip if explicitly marked to skip
1334                if let Some(cfg) = config
1335                    && cfg.skip
1336                {
1337                    continue;
1338                }
1339
1340                // For non-template targets, only include fields that are:
1341                // 1. Explicitly marked for this target with #[prompt_for(name = "Target")]
1342                // 2. Not marked for any specific target (default fields)
1343                let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
1344                let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
1345                    .iter()
1346                    .any(|(name, _)| name != "*");
1347
1348                if has_any_target_specific_config && !is_explicitly_for_this_target {
1349                    continue;
1350                }
1351
1352                if let Some(cfg) = config {
1353                    if cfg.image {
1354                        image_field_parts.push(quote! {
1355                            parts.extend(self.#field_name.to_prompt_parts());
1356                        });
1357                    } else {
1358                        let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
1359
1360                        let value_expr = if let Some(format_with) = &cfg.format_with {
1361                            // Parse the function path - if it fails, generate code that will produce a compile error
1362                            match syn::parse_str::<syn::Path>(format_with) {
1363                                Ok(func_path) => quote! { #func_path(&self.#field_name) },
1364                                Err(_) => {
1365                                    // Generate a compile error by using an invalid identifier
1366                                    let error_msg = format!(
1367                                        "Invalid function path in format_with: '{}'",
1368                                        format_with
1369                                    );
1370                                    quote! {
1371                                        compile_error!(#error_msg);
1372                                        String::new()
1373                                    }
1374                                }
1375                            }
1376                        } else {
1377                            quote! { self.#field_name.to_prompt() }
1378                        };
1379
1380                        text_field_parts.push(quote! {
1381                            text_parts.push(format!("{}: {}", #key, #value_expr));
1382                        });
1383                    }
1384                } else {
1385                    // Default handling for fields without specific config
1386                    text_field_parts.push(quote! {
1387                        text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
1388                    });
1389                }
1390            }
1391
1392            match_arms.push(quote! {
1393                #target_name => {
1394                    let mut parts = Vec::new();
1395
1396                    #(#image_field_parts)*
1397
1398                    let mut text_parts = Vec::new();
1399                    #(#text_field_parts)*
1400
1401                    if !text_parts.is_empty() {
1402                        parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
1403                    }
1404
1405                    Ok(parts)
1406                }
1407            });
1408        }
1409    }
1410
1411    // Collect all target names for error reporting
1412    let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
1413
1414    // Add default case for unknown targets
1415    match_arms.push(quote! {
1416        _ => {
1417            let available = vec![#(#target_names.to_string()),*];
1418            Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
1419                target: target.to_string(),
1420                available,
1421            })
1422        }
1423    });
1424
1425    let struct_name = &input.ident;
1426    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1427
1428    let expanded = quote! {
1429        impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
1430            fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
1431                match target {
1432                    #(#match_arms)*
1433                }
1434            }
1435        }
1436    };
1437
1438    TokenStream::from(expanded)
1439}
1440
1441/// Wrapper struct for parsing a comma-separated list of types
1442struct TypeList {
1443    types: Punctuated<syn::Type, Token![,]>,
1444}
1445
1446impl Parse for TypeList {
1447    fn parse(input: ParseStream) -> syn::Result<Self> {
1448        Ok(TypeList {
1449            types: Punctuated::parse_terminated(input)?,
1450        })
1451    }
1452}
1453
1454/// Generates a formatted Markdown examples section for the provided types.
1455///
1456/// This macro accepts a comma-separated list of types and generates a single
1457/// formatted Markdown string containing examples of each type.
1458///
1459/// # Example
1460///
1461/// ```rust,ignore
1462/// let examples = examples_section!(User, Concept);
1463/// // Produces a string like:
1464/// // ---
1465/// // ### Examples
1466/// //
1467/// // Here are examples of the data structures you should use.
1468/// //
1469/// // ---
1470/// // #### `User`
1471/// // {...json...}
1472/// // ---
1473/// // #### `Concept`
1474/// // {...json...}
1475/// // ---
1476/// ```
1477#[proc_macro]
1478pub fn examples_section(input: TokenStream) -> TokenStream {
1479    let input = parse_macro_input!(input as TypeList);
1480
1481    // Generate code for each type
1482    let mut type_sections = Vec::new();
1483
1484    for ty in input.types.iter() {
1485        // Extract the type name as a string
1486        let type_name_str = quote!(#ty).to_string();
1487
1488        // Generate the section for this type
1489        type_sections.push(quote! {
1490            {
1491                let type_name = #type_name_str;
1492                let json_example = <#ty as Default>::default().to_prompt_with_mode("example_only");
1493                format!("---\n#### `{}`\n{}", type_name, json_example)
1494            }
1495        });
1496    }
1497
1498    // Build the complete examples string
1499    let expanded = quote! {
1500        {
1501            let mut sections = Vec::new();
1502            sections.push("---".to_string());
1503            sections.push("### Examples".to_string());
1504            sections.push("".to_string());
1505            sections.push("Here are examples of the data structures you should use.".to_string());
1506            sections.push("".to_string());
1507
1508            #(sections.push(#type_sections);)*
1509
1510            sections.push("---".to_string());
1511
1512            sections.join("\n")
1513        }
1514    };
1515
1516    TokenStream::from(expanded)
1517}
1518
1519/// Helper function to parse struct-level #[prompt_for(target = "...", template = "...")] attribute
1520fn parse_to_prompt_for_attribute(attrs: &[syn::Attribute]) -> (syn::Type, String) {
1521    for attr in attrs {
1522        if attr.path().is_ident("prompt_for")
1523            && let Ok(meta_list) = attr.meta.require_list()
1524            && let Ok(metas) =
1525                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1526        {
1527            let mut target_type = None;
1528            let mut template = None;
1529
1530            for meta in metas {
1531                match meta {
1532                    Meta::NameValue(nv) if nv.path.is_ident("target") => {
1533                        if let syn::Expr::Lit(syn::ExprLit {
1534                            lit: syn::Lit::Str(lit_str),
1535                            ..
1536                        }) = nv.value
1537                        {
1538                            // Parse the type string into a syn::Type
1539                            target_type = syn::parse_str::<syn::Type>(&lit_str.value()).ok();
1540                        }
1541                    }
1542                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
1543                        if let syn::Expr::Lit(syn::ExprLit {
1544                            lit: syn::Lit::Str(lit_str),
1545                            ..
1546                        }) = nv.value
1547                        {
1548                            template = Some(lit_str.value());
1549                        }
1550                    }
1551                    _ => {}
1552                }
1553            }
1554
1555            if let (Some(target), Some(tmpl)) = (target_type, template) {
1556                return (target, tmpl);
1557            }
1558        }
1559    }
1560
1561    panic!("ToPromptFor requires #[prompt_for(target = \"TargetType\", template = \"...\")]");
1562}
1563
1564/// A procedural attribute macro that generates prompt-building functions and extractor structs for intent enums.
1565///
1566/// This macro should be applied to an enum to generate:
1567/// 1. A prompt-building function that incorporates enum documentation
1568/// 2. An extractor struct that implements `IntentExtractor`
1569///
1570/// # Requirements
1571///
1572/// The enum must have an `#[intent(...)]` attribute with:
1573/// - `prompt`: The prompt template (supports Jinja-style variables)
1574/// - `extractor_tag`: The tag to use for extraction
1575///
1576/// # Example
1577///
1578/// ```rust,ignore
1579/// #[define_intent]
1580/// #[intent(
1581///     prompt = "Analyze the intent: {{ user_input }}",
1582///     extractor_tag = "intent"
1583/// )]
1584/// enum MyIntent {
1585///     /// Create a new item
1586///     Create,
1587///     /// Update an existing item
1588///     Update,
1589///     /// Delete an item
1590///     Delete,
1591/// }
1592/// ```
1593///
1594/// This will generate:
1595/// - `pub fn build_my_intent_prompt(user_input: &str) -> String`
1596/// - `pub struct MyIntentExtractor;` with `IntentExtractor<MyIntent>` implementation
1597#[proc_macro_attribute]
1598pub fn define_intent(_attr: TokenStream, item: TokenStream) -> TokenStream {
1599    let input = parse_macro_input!(item as DeriveInput);
1600
1601    // Verify this is an enum
1602    let enum_data = match &input.data {
1603        Data::Enum(data) => data,
1604        _ => {
1605            return syn::Error::new(
1606                input.ident.span(),
1607                "`#[define_intent]` can only be applied to enums",
1608            )
1609            .to_compile_error()
1610            .into();
1611        }
1612    };
1613
1614    // Parse the #[intent(...)] attribute
1615    let mut prompt_template = None;
1616    let mut extractor_tag = None;
1617
1618    for attr in &input.attrs {
1619        if attr.path().is_ident("intent")
1620            && let Ok(metas) =
1621                attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1622        {
1623            for meta in metas {
1624                match meta {
1625                    Meta::NameValue(nv) if nv.path.is_ident("prompt") => {
1626                        if let syn::Expr::Lit(syn::ExprLit {
1627                            lit: syn::Lit::Str(lit_str),
1628                            ..
1629                        }) = nv.value
1630                        {
1631                            prompt_template = Some(lit_str.value());
1632                        }
1633                    }
1634                    Meta::NameValue(nv) if nv.path.is_ident("extractor_tag") => {
1635                        if let syn::Expr::Lit(syn::ExprLit {
1636                            lit: syn::Lit::Str(lit_str),
1637                            ..
1638                        }) = nv.value
1639                        {
1640                            extractor_tag = Some(lit_str.value());
1641                        }
1642                    }
1643                    _ => {}
1644                }
1645            }
1646        }
1647    }
1648
1649    // Validate required attributes
1650    let prompt_template = match prompt_template {
1651        Some(p) => p,
1652        None => {
1653            return syn::Error::new(
1654                input.ident.span(),
1655                "`#[intent(...)]` attribute must include `prompt = \"...\"`",
1656            )
1657            .to_compile_error()
1658            .into();
1659        }
1660    };
1661
1662    let extractor_tag = match extractor_tag {
1663        Some(t) => t,
1664        None => {
1665            return syn::Error::new(
1666                input.ident.span(),
1667                "`#[intent(...)]` attribute must include `extractor_tag = \"...\"`",
1668            )
1669            .to_compile_error()
1670            .into();
1671        }
1672    };
1673
1674    // Generate the intents documentation
1675    let enum_name = &input.ident;
1676    let enum_docs = extract_doc_comments(&input.attrs);
1677
1678    let mut intents_doc_lines = Vec::new();
1679
1680    // Add enum description if present
1681    if !enum_docs.is_empty() {
1682        intents_doc_lines.push(format!("{}: {}", enum_name, enum_docs));
1683    } else {
1684        intents_doc_lines.push(format!("{}:", enum_name));
1685    }
1686    intents_doc_lines.push(String::new()); // Empty line
1687    intents_doc_lines.push("Possible values:".to_string());
1688
1689    // Add each variant with its documentation
1690    for variant in &enum_data.variants {
1691        let variant_name = &variant.ident;
1692        let variant_docs = extract_doc_comments(&variant.attrs);
1693
1694        if !variant_docs.is_empty() {
1695            intents_doc_lines.push(format!("- {}: {}", variant_name, variant_docs));
1696        } else {
1697            intents_doc_lines.push(format!("- {}", variant_name));
1698        }
1699    }
1700
1701    let intents_doc_str = intents_doc_lines.join("\n");
1702
1703    // Parse template variables (excluding intents_doc which we'll inject)
1704    let placeholders = parse_template_placeholders(&prompt_template);
1705    let user_variables: Vec<String> = placeholders
1706        .iter()
1707        .filter_map(|(name, _)| {
1708            if name != "intents_doc" {
1709                Some(name.clone())
1710            } else {
1711                None
1712            }
1713        })
1714        .collect();
1715
1716    // Generate function name (snake_case)
1717    let enum_name_str = enum_name.to_string();
1718    let snake_case_name = to_snake_case(&enum_name_str);
1719    let function_name = syn::Ident::new(
1720        &format!("build_{}_prompt", snake_case_name),
1721        proc_macro2::Span::call_site(),
1722    );
1723
1724    // Generate function parameters (all &str for simplicity)
1725    let function_params: Vec<proc_macro2::TokenStream> = user_variables
1726        .iter()
1727        .map(|var| {
1728            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1729            quote! { #ident: &str }
1730        })
1731        .collect();
1732
1733    // Generate context insertions
1734    let context_insertions: Vec<proc_macro2::TokenStream> = user_variables
1735        .iter()
1736        .map(|var| {
1737            let var_str = var.clone();
1738            let ident = syn::Ident::new(var, proc_macro2::Span::call_site());
1739            quote! {
1740                __template_context.insert(#var_str.to_string(), minijinja::Value::from(#ident));
1741            }
1742        })
1743        .collect();
1744
1745    // Convert template to minijinja syntax
1746    let converted_template = convert_to_minijinja_syntax(&prompt_template);
1747
1748    // Generate extractor struct name
1749    let extractor_name = syn::Ident::new(
1750        &format!("{}Extractor", enum_name),
1751        proc_macro2::Span::call_site(),
1752    );
1753
1754    // Filter out the #[intent(...)] attribute from the enum attributes
1755    let filtered_attrs: Vec<_> = input
1756        .attrs
1757        .iter()
1758        .filter(|attr| !attr.path().is_ident("intent"))
1759        .collect();
1760
1761    // Rebuild the enum with filtered attributes
1762    let vis = &input.vis;
1763    let generics = &input.generics;
1764    let variants = &enum_data.variants;
1765    let enum_output = quote! {
1766        #(#filtered_attrs)*
1767        #vis enum #enum_name #generics {
1768            #variants
1769        }
1770    };
1771
1772    // Generate the complete output
1773    let expanded = quote! {
1774        // Output the enum without the #[intent(...)] attribute
1775        #enum_output
1776
1777        // Generate the prompt-building function
1778        pub fn #function_name(#(#function_params),*) -> String {
1779            let mut env = minijinja::Environment::new();
1780            env.add_template("prompt", #converted_template)
1781                .expect("Failed to parse intent prompt template");
1782
1783            let tmpl = env.get_template("prompt").unwrap();
1784
1785            let mut __template_context = std::collections::HashMap::new();
1786
1787            // Add intents_doc
1788            __template_context.insert("intents_doc".to_string(), minijinja::Value::from(#intents_doc_str));
1789
1790            // Add user-provided variables
1791            #(#context_insertions)*
1792
1793            tmpl.render(&__template_context)
1794                .unwrap_or_else(|e| format!("Failed to render intent prompt: {}", e))
1795        }
1796
1797        // Generate the extractor struct
1798        pub struct #extractor_name;
1799
1800        impl #extractor_name {
1801            pub const EXTRACTOR_TAG: &'static str = #extractor_tag;
1802        }
1803
1804        impl llm_toolkit::intent::IntentExtractor<#enum_name> for #extractor_name {
1805            fn extract_intent(&self, response: &str) -> Result<#enum_name, llm_toolkit::intent::IntentExtractionError> {
1806                // Use the common extraction function with our tag
1807                llm_toolkit::intent::extract_intent_from_response(response, Self::EXTRACTOR_TAG)
1808            }
1809        }
1810    };
1811
1812    TokenStream::from(expanded)
1813}
1814
/// Converts a PascalCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character only when the previous character
/// was not uppercase. Runs of capitals therefore stay together: `"MyIntent"` becomes
/// `"my_intent"`, and `"HTTPIntent"` becomes `"httpintent"`. Uppercase characters are replaced
/// by their full lowercase mapping, which may be more than one `char` (e.g. `'İ'` becomes
/// `"i\u{307}"`). All other characters are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut prev_upper = false;

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 && !prev_upper {
                result.push('_');
            }
            // `to_lowercase` yields an iterator; keeping only its first char would silently
            // drop part of multi-char mappings.
            result.extend(ch.to_lowercase());
            prev_upper = true;
        } else {
            result.push(ch);
            prev_upper = false;
        }
    }

    result
}
1835
1836/// Derives the `ToPromptFor` trait for a struct
1837#[proc_macro_derive(ToPromptFor, attributes(prompt_for))]
1838pub fn to_prompt_for_derive(input: TokenStream) -> TokenStream {
1839    let input = parse_macro_input!(input as DeriveInput);
1840
1841    // Parse the struct-level prompt_for attribute
1842    let (target_type, template) = parse_to_prompt_for_attribute(&input.attrs);
1843
1844    let struct_name = &input.ident;
1845    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1846
1847    // Parse the template to find placeholders
1848    let placeholders = parse_template_placeholders(&template);
1849
1850    // Convert template to minijinja syntax and build context generation code
1851    let mut converted_template = template.clone();
1852    let mut context_fields = Vec::new();
1853
1854    // Get struct fields for validation
1855    let fields = match &input.data {
1856        Data::Struct(data_struct) => match &data_struct.fields {
1857            syn::Fields::Named(fields) => &fields.named,
1858            _ => panic!("ToPromptFor is only supported for structs with named fields"),
1859        },
1860        _ => panic!("ToPromptFor is only supported for structs"),
1861    };
1862
1863    // Check if the struct has mode support (has #[prompt(mode = ...)] attribute)
1864    let has_mode_support = input.attrs.iter().any(|attr| {
1865        if attr.path().is_ident("prompt")
1866            && let Ok(metas) =
1867                attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
1868        {
1869            for meta in metas {
1870                if let Meta::NameValue(nv) = meta
1871                    && nv.path.is_ident("mode")
1872                {
1873                    return true;
1874                }
1875            }
1876        }
1877        false
1878    });
1879
1880    // Process each placeholder
1881    for (placeholder_name, mode_opt) in &placeholders {
1882        if placeholder_name == "self" {
1883            if let Some(specific_mode) = mode_opt {
1884                // {self:some_mode} - use a unique key
1885                let unique_key = format!("self__{}", specific_mode);
1886
1887                // Replace {self:mode} with {{self__mode}} in template
1888                let pattern = format!("{{self:{}}}", specific_mode);
1889                let replacement = format!("{{{{{}}}}}", unique_key);
1890                converted_template = converted_template.replace(&pattern, &replacement);
1891
1892                // Add to context with the specific mode
1893                context_fields.push(quote! {
1894                    context.insert(
1895                        #unique_key.to_string(),
1896                        minijinja::Value::from(self.to_prompt_with_mode(#specific_mode))
1897                    );
1898                });
1899            } else {
1900                // {self} - behavior depends on whether the struct has mode support
1901                let pattern = "{self}";
1902                let replacement = "{{self}}";
1903                converted_template = converted_template.replace(pattern, replacement);
1904
1905                if has_mode_support {
1906                    // If the struct has mode support, use to_prompt_with_mode with the mode parameter
1907                    context_fields.push(quote! {
1908                        context.insert(
1909                            "self".to_string(),
1910                            minijinja::Value::from(self.to_prompt_with_mode(mode))
1911                        );
1912                    });
1913                } else {
1914                    // If the struct doesn't have mode support, use to_prompt() which gives key-value format
1915                    context_fields.push(quote! {
1916                        context.insert(
1917                            "self".to_string(),
1918                            minijinja::Value::from(self.to_prompt())
1919                        );
1920                    });
1921                }
1922            }
1923        } else {
1924            // It's a field placeholder
1925            // Check if the field exists
1926            let field_exists = fields.iter().any(|f| {
1927                f.ident
1928                    .as_ref()
1929                    .is_some_and(|ident| ident == placeholder_name)
1930            });
1931
1932            if field_exists {
1933                let field_ident = syn::Ident::new(placeholder_name, proc_macro2::Span::call_site());
1934
1935                // Replace {field} with {{field}} in template
1936                let pattern = format!("{{{}}}", placeholder_name);
1937                let replacement = format!("{{{{{}}}}}", placeholder_name);
1938                converted_template = converted_template.replace(&pattern, &replacement);
1939
1940                // Add field to context - serialize the field value
1941                context_fields.push(quote! {
1942                    context.insert(
1943                        #placeholder_name.to_string(),
1944                        minijinja::Value::from_serialize(&self.#field_ident)
1945                    );
1946                });
1947            }
1948            // If field doesn't exist, we'll let minijinja handle the error at runtime
1949        }
1950    }
1951
1952    let expanded = quote! {
1953        impl #impl_generics llm_toolkit::prompt::ToPromptFor<#target_type> for #struct_name #ty_generics #where_clause
1954        where
1955            #target_type: serde::Serialize,
1956        {
1957            fn to_prompt_for_with_mode(&self, target: &#target_type, mode: &str) -> String {
1958                // Create minijinja environment and add template
1959                let mut env = minijinja::Environment::new();
1960                env.add_template("prompt", #converted_template).unwrap_or_else(|e| {
1961                    panic!("Failed to parse template: {}", e)
1962                });
1963
1964                let tmpl = env.get_template("prompt").unwrap();
1965
1966                // Build context
1967                let mut context = std::collections::HashMap::new();
1968                // Add self to the context for field access in templates
1969                context.insert(
1970                    "self".to_string(),
1971                    minijinja::Value::from_serialize(self)
1972                );
1973                // Add target to the context
1974                context.insert(
1975                    "target".to_string(),
1976                    minijinja::Value::from_serialize(target)
1977                );
1978                #(#context_fields)*
1979
1980                // Render template
1981                tmpl.render(context).unwrap_or_else(|e| {
1982                    format!("Failed to render prompt: {}", e)
1983                })
1984            }
1985        }
1986    };
1987
1988    TokenStream::from(expanded)
1989}