llm_toolkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Meta, parse_macro_input, punctuated::Punctuated};
4
5/// Extract doc comments from attributes
6fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if attr.path().is_ident("doc")
11                && let syn::Meta::NameValue(meta_name_value) = &attr.meta
12                && let syn::Expr::Lit(syn::ExprLit {
13                    lit: syn::Lit::Str(lit_str),
14                    ..
15                }) = &meta_name_value.value
16            {
17                return Some(lit_str.value());
18            }
19            None
20        })
21        .map(|s| s.trim().to_string())
22        .collect::<Vec<_>>()
23        .join(" ")
24}
25
26/// Result of parsing prompt attribute
27enum PromptAttribute {
28    Skip,
29    Description(String),
30    None,
31}
32
33/// Parse #[prompt(...)] attribute on enum variant
34fn parse_prompt_attribute(attrs: &[syn::Attribute]) -> PromptAttribute {
35    for attr in attrs {
36        if attr.path().is_ident("prompt") {
37            // Check for #[prompt(skip)]
38            if let Ok(meta_list) = attr.meta.require_list() {
39                let tokens = &meta_list.tokens;
40                let tokens_str = tokens.to_string();
41                if tokens_str == "skip" {
42                    return PromptAttribute::Skip;
43                }
44            }
45
46            // Check for #[prompt("description")]
47            if let Ok(lit_str) = attr.parse_args::<syn::LitStr>() {
48                return PromptAttribute::Description(lit_str.value());
49            }
50        }
51    }
52    PromptAttribute::None
53}
54
55/// Parsed field-level prompt attributes
56#[derive(Debug, Default)]
57struct FieldPromptAttrs {
58    skip: bool,
59    rename: Option<String>,
60    format_with: Option<String>,
61    image: bool,
62}
63
64/// Parse #[prompt(...)] attributes for struct fields
65fn parse_field_prompt_attrs(attrs: &[syn::Attribute]) -> FieldPromptAttrs {
66    let mut result = FieldPromptAttrs::default();
67
68    for attr in attrs {
69        if attr.path().is_ident("prompt") {
70            // Try to parse as meta list #[prompt(key = value, ...)]
71            if let Ok(meta_list) = attr.meta.require_list() {
72                // Parse the tokens inside the parentheses
73                if let Ok(metas) =
74                    meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
75                {
76                    for meta in metas {
77                        match meta {
78                            Meta::Path(path) if path.is_ident("skip") => {
79                                result.skip = true;
80                            }
81                            Meta::NameValue(nv) if nv.path.is_ident("rename") => {
82                                if let syn::Expr::Lit(syn::ExprLit {
83                                    lit: syn::Lit::Str(lit_str),
84                                    ..
85                                }) = nv.value
86                                {
87                                    result.rename = Some(lit_str.value());
88                                }
89                            }
90                            Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
91                                if let syn::Expr::Lit(syn::ExprLit {
92                                    lit: syn::Lit::Str(lit_str),
93                                    ..
94                                }) = nv.value
95                                {
96                                    result.format_with = Some(lit_str.value());
97                                }
98                            }
99                            Meta::Path(path) if path.is_ident("image") => {
100                                result.image = true;
101                            }
102                            _ => {}
103                        }
104                    }
105                } else if meta_list.tokens.to_string() == "skip" {
106                    // Handle simple #[prompt(skip)] case
107                    result.skip = true;
108                } else if meta_list.tokens.to_string() == "image" {
109                    // Handle simple #[prompt(image)] case
110                    result.image = true;
111                }
112            }
113        }
114    }
115
116    result
117}
118
119#[proc_macro_derive(ToPrompt, attributes(prompt))]
120pub fn to_prompt_derive(input: TokenStream) -> TokenStream {
121    let input = parse_macro_input!(input as DeriveInput);
122
123    // Check if this is a struct or enum
124    match &input.data {
125        Data::Enum(data_enum) => {
126            // For enums, generate prompt from doc comments
127            let enum_name = &input.ident;
128            let enum_docs = extract_doc_comments(&input.attrs);
129
130            let mut prompt_lines = Vec::new();
131
132            // Add enum description
133            if !enum_docs.is_empty() {
134                prompt_lines.push(format!("{}: {}", enum_name, enum_docs));
135            } else {
136                prompt_lines.push(format!("{}:", enum_name));
137            }
138            prompt_lines.push(String::new()); // Empty line
139            prompt_lines.push("Possible values:".to_string());
140
141            // Add each variant with its documentation based on priority
142            for variant in &data_enum.variants {
143                let variant_name = &variant.ident;
144
145                // Apply fallback logic with priority
146                match parse_prompt_attribute(&variant.attrs) {
147                    PromptAttribute::Skip => {
148                        // Skip this variant completely
149                        continue;
150                    }
151                    PromptAttribute::Description(desc) => {
152                        // Use custom description from #[prompt("...")]
153                        prompt_lines.push(format!("- {}: {}", variant_name, desc));
154                    }
155                    PromptAttribute::None => {
156                        // Fall back to doc comment or just variant name
157                        let variant_docs = extract_doc_comments(&variant.attrs);
158                        if !variant_docs.is_empty() {
159                            prompt_lines.push(format!("- {}: {}", variant_name, variant_docs));
160                        } else {
161                            prompt_lines.push(format!("- {}", variant_name));
162                        }
163                    }
164                }
165            }
166
167            let prompt_string = prompt_lines.join("\n");
168            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
169
170            let expanded = quote! {
171                impl #impl_generics llm_toolkit::prompt::ToPrompt for #enum_name #ty_generics #where_clause {
172                    fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
173                        vec![llm_toolkit::prompt::PromptPart::Text(#prompt_string.to_string())]
174                    }
175
176                    fn to_prompt(&self) -> String {
177                        #prompt_string.to_string()
178                    }
179                }
180            };
181
182            TokenStream::from(expanded)
183        }
184        Data::Struct(data_struct) => {
185            // Check if there's a #[prompt(template = "...")] attribute
186            let template_attr = input
187                .attrs
188                .iter()
189                .find(|attr| attr.path().is_ident("prompt"))
190                .and_then(|attr| {
191                    // Try to parse the attribute arguments
192                    attr.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
193                        .ok()
194                        .and_then(|metas| {
195                            metas.into_iter().find_map(|meta| match meta {
196                                Meta::NameValue(nv) if nv.path.is_ident("template") => {
197                                    if let syn::Expr::Lit(expr_lit) = nv.value {
198                                        if let syn::Lit::Str(lit_str) = expr_lit.lit {
199                                            Some(lit_str.value())
200                                        } else {
201                                            None
202                                        }
203                                    } else {
204                                        None
205                                    }
206                                }
207                                _ => None,
208                            })
209                        })
210                });
211
212            let name = input.ident;
213            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
214
215            let expanded = if let Some(template_str) = template_attr {
216                // Use template-based approach if template is provided
217                // Collect image fields separately for to_prompt_parts()
218                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
219                    &fields.named
220                } else {
221                    panic!(
222                        "Template prompt generation is only supported for structs with named fields."
223                    );
224                };
225
226                let mut image_field_parts = Vec::new();
227                for f in fields.iter() {
228                    let field_name = f.ident.as_ref().unwrap();
229                    let attrs = parse_field_prompt_attrs(&f.attrs);
230
231                    if attrs.image {
232                        // This field is marked as an image
233                        image_field_parts.push(quote! {
234                            parts.extend(self.#field_name.to_prompt_parts());
235                        });
236                    }
237                }
238
239                quote! {
240                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
241                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
242                            let mut parts = Vec::new();
243
244                            // Add image parts first
245                            #(#image_field_parts)*
246
247                            // Add the rendered template as text
248                            let text = llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
249                                format!("Failed to render prompt: {}", e)
250                            });
251                            if !text.is_empty() {
252                                parts.push(llm_toolkit::prompt::PromptPart::Text(text));
253                            }
254
255                            parts
256                        }
257
258                        fn to_prompt(&self) -> String {
259                            llm_toolkit::prompt::render_prompt(#template_str, self).unwrap_or_else(|e| {
260                                format!("Failed to render prompt: {}", e)
261                            })
262                        }
263                    }
264                }
265            } else {
266                // Use default key-value format if no template is provided
267                // Now also generate to_prompt_parts() for multimodal support
268                let fields = if let syn::Fields::Named(fields) = &data_struct.fields {
269                    &fields.named
270                } else {
271                    panic!(
272                        "Default prompt generation is only supported for structs with named fields."
273                    );
274                };
275
276                // Separate image fields from text fields
277                let mut text_field_parts = Vec::new();
278                let mut image_field_parts = Vec::new();
279
280                for f in fields.iter() {
281                    let field_name = f.ident.as_ref().unwrap();
282                    let attrs = parse_field_prompt_attrs(&f.attrs);
283
284                    // Skip if #[prompt(skip)] is present
285                    if attrs.skip {
286                        continue;
287                    }
288
289                    if attrs.image {
290                        // This field is marked as an image
291                        image_field_parts.push(quote! {
292                            parts.extend(self.#field_name.to_prompt_parts());
293                        });
294                    } else {
295                        // This is a regular text field
296                        // Determine the key based on priority:
297                        // 1. #[prompt(rename = "new_name")]
298                        // 2. Doc comment
299                        // 3. Field name (fallback)
300                        let key = if let Some(rename) = attrs.rename {
301                            rename
302                        } else {
303                            let doc_comment = extract_doc_comments(&f.attrs);
304                            if !doc_comment.is_empty() {
305                                doc_comment
306                            } else {
307                                field_name.to_string()
308                            }
309                        };
310
311                        // Determine the value based on format_with attribute
312                        let value_expr = if let Some(format_with) = attrs.format_with {
313                            // Parse the function path string into a syn::Path
314                            let func_path: syn::Path =
315                                syn::parse_str(&format_with).unwrap_or_else(|_| {
316                                    panic!("Invalid function path: {}", format_with)
317                                });
318                            quote! { #func_path(&self.#field_name) }
319                        } else {
320                            quote! { self.#field_name.to_prompt() }
321                        };
322
323                        text_field_parts.push(quote! {
324                            text_parts.push(format!("{}: {}", #key, #value_expr));
325                        });
326                    }
327                }
328
329                // Generate the implementation with to_prompt_parts()
330                quote! {
331                    impl #impl_generics llm_toolkit::prompt::ToPrompt for #name #ty_generics #where_clause {
332                        fn to_prompt_parts(&self) -> Vec<llm_toolkit::prompt::PromptPart> {
333                            let mut parts = Vec::new();
334
335                            // Add image parts first
336                            #(#image_field_parts)*
337
338                            // Collect text parts and add as a single text prompt part
339                            let mut text_parts = Vec::new();
340                            #(#text_field_parts)*
341
342                            if !text_parts.is_empty() {
343                                parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
344                            }
345
346                            parts
347                        }
348
349                        fn to_prompt(&self) -> String {
350                            let mut text_parts = Vec::new();
351                            #(#text_field_parts)*
352                            text_parts.join("\n")
353                        }
354                    }
355                }
356            };
357
358            TokenStream::from(expanded)
359        }
360        Data::Union(_) => {
361            panic!("`#[derive(ToPrompt)]` is not supported for unions");
362        }
363    }
364}
365
366/// Information about a prompt target
367#[derive(Debug, Clone)]
368struct TargetInfo {
369    name: String,
370    template: Option<String>,
371    field_configs: std::collections::HashMap<String, FieldTargetConfig>,
372}
373
374/// Configuration for how a field should be handled for a specific target
375#[derive(Debug, Clone, Default)]
376struct FieldTargetConfig {
377    skip: bool,
378    rename: Option<String>,
379    format_with: Option<String>,
380    image: bool,
381    include_only: bool, // true if this field is specifically included for this target
382}
383
384/// Parse #[prompt_for(...)] attributes for ToPromptSet
385fn parse_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<(String, FieldTargetConfig)> {
386    let mut configs = Vec::new();
387
388    for attr in attrs {
389        if attr.path().is_ident("prompt_for")
390            && let Ok(meta_list) = attr.meta.require_list()
391        {
392            // Try to parse as meta list
393            if meta_list.tokens.to_string() == "skip" {
394                // Simple #[prompt_for(skip)] applies to all targets
395                let config = FieldTargetConfig {
396                    skip: true,
397                    ..Default::default()
398                };
399                configs.push(("*".to_string(), config));
400            } else if let Ok(metas) =
401                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
402            {
403                let mut target_name = None;
404                let mut config = FieldTargetConfig::default();
405
406                for meta in metas {
407                    match meta {
408                        Meta::NameValue(nv) if nv.path.is_ident("name") => {
409                            if let syn::Expr::Lit(syn::ExprLit {
410                                lit: syn::Lit::Str(lit_str),
411                                ..
412                            }) = nv.value
413                            {
414                                target_name = Some(lit_str.value());
415                            }
416                        }
417                        Meta::Path(path) if path.is_ident("skip") => {
418                            config.skip = true;
419                        }
420                        Meta::NameValue(nv) if nv.path.is_ident("rename") => {
421                            if let syn::Expr::Lit(syn::ExprLit {
422                                lit: syn::Lit::Str(lit_str),
423                                ..
424                            }) = nv.value
425                            {
426                                config.rename = Some(lit_str.value());
427                            }
428                        }
429                        Meta::NameValue(nv) if nv.path.is_ident("format_with") => {
430                            if let syn::Expr::Lit(syn::ExprLit {
431                                lit: syn::Lit::Str(lit_str),
432                                ..
433                            }) = nv.value
434                            {
435                                config.format_with = Some(lit_str.value());
436                            }
437                        }
438                        Meta::Path(path) if path.is_ident("image") => {
439                            config.image = true;
440                        }
441                        _ => {}
442                    }
443                }
444
445                if let Some(name) = target_name {
446                    config.include_only = true;
447                    configs.push((name, config));
448                }
449            }
450        }
451    }
452
453    configs
454}
455
456/// Parse struct-level #[prompt_for(...)] attributes to find target templates
457fn parse_struct_prompt_for_attrs(attrs: &[syn::Attribute]) -> Vec<TargetInfo> {
458    let mut targets = Vec::new();
459
460    for attr in attrs {
461        if attr.path().is_ident("prompt_for")
462            && let Ok(meta_list) = attr.meta.require_list()
463            && let Ok(metas) =
464                meta_list.parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
465        {
466            let mut target_name = None;
467            let mut template = None;
468
469            for meta in metas {
470                match meta {
471                    Meta::NameValue(nv) if nv.path.is_ident("name") => {
472                        if let syn::Expr::Lit(syn::ExprLit {
473                            lit: syn::Lit::Str(lit_str),
474                            ..
475                        }) = nv.value
476                        {
477                            target_name = Some(lit_str.value());
478                        }
479                    }
480                    Meta::NameValue(nv) if nv.path.is_ident("template") => {
481                        if let syn::Expr::Lit(syn::ExprLit {
482                            lit: syn::Lit::Str(lit_str),
483                            ..
484                        }) = nv.value
485                        {
486                            template = Some(lit_str.value());
487                        }
488                    }
489                    _ => {}
490                }
491            }
492
493            if let Some(name) = target_name {
494                targets.push(TargetInfo {
495                    name,
496                    template,
497                    field_configs: std::collections::HashMap::new(),
498                });
499            }
500        }
501    }
502
503    targets
504}
505
506#[proc_macro_derive(ToPromptSet, attributes(prompt_for))]
507pub fn to_prompt_set_derive(input: TokenStream) -> TokenStream {
508    let input = parse_macro_input!(input as DeriveInput);
509
510    // Only support structs with named fields
511    let data_struct = match &input.data {
512        Data::Struct(data) => data,
513        _ => {
514            return syn::Error::new(
515                input.ident.span(),
516                "`#[derive(ToPromptSet)]` is only supported for structs",
517            )
518            .to_compile_error()
519            .into();
520        }
521    };
522
523    let fields = match &data_struct.fields {
524        syn::Fields::Named(fields) => &fields.named,
525        _ => {
526            return syn::Error::new(
527                input.ident.span(),
528                "`#[derive(ToPromptSet)]` is only supported for structs with named fields",
529            )
530            .to_compile_error()
531            .into();
532        }
533    };
534
535    // Parse struct-level attributes to find targets
536    let mut targets = parse_struct_prompt_for_attrs(&input.attrs);
537
538    // Parse field-level attributes
539    for field in fields.iter() {
540        let field_name = field.ident.as_ref().unwrap().to_string();
541        let field_configs = parse_prompt_for_attrs(&field.attrs);
542
543        for (target_name, config) in field_configs {
544            if target_name == "*" {
545                // Apply to all targets
546                for target in &mut targets {
547                    target
548                        .field_configs
549                        .entry(field_name.clone())
550                        .or_insert_with(FieldTargetConfig::default)
551                        .skip = config.skip;
552                }
553            } else {
554                // Find or create the target
555                let target_exists = targets.iter().any(|t| t.name == target_name);
556                if !target_exists {
557                    // Add implicit target if not defined at struct level
558                    targets.push(TargetInfo {
559                        name: target_name.clone(),
560                        template: None,
561                        field_configs: std::collections::HashMap::new(),
562                    });
563                }
564
565                let target = targets.iter_mut().find(|t| t.name == target_name).unwrap();
566
567                target.field_configs.insert(field_name.clone(), config);
568            }
569        }
570    }
571
572    // Generate match arms for each target
573    let mut match_arms = Vec::new();
574
575    for target in &targets {
576        let target_name = &target.name;
577
578        if let Some(template_str) = &target.template {
579            // Template-based generation
580            let mut image_parts = Vec::new();
581
582            for field in fields.iter() {
583                let field_name = field.ident.as_ref().unwrap();
584                let field_name_str = field_name.to_string();
585
586                if let Some(config) = target.field_configs.get(&field_name_str)
587                    && config.image
588                {
589                    image_parts.push(quote! {
590                        parts.extend(self.#field_name.to_prompt_parts());
591                    });
592                }
593            }
594
595            match_arms.push(quote! {
596                #target_name => {
597                    let mut parts = Vec::new();
598
599                    #(#image_parts)*
600
601                    let text = llm_toolkit::prompt::render_prompt(#template_str, self)
602                        .map_err(|e| llm_toolkit::prompt::PromptSetError::RenderFailed {
603                            target: #target_name.to_string(),
604                            source: e,
605                        })?;
606
607                    if !text.is_empty() {
608                        parts.push(llm_toolkit::prompt::PromptPart::Text(text));
609                    }
610
611                    Ok(parts)
612                }
613            });
614        } else {
615            // Key-value based generation
616            let mut text_field_parts = Vec::new();
617            let mut image_field_parts = Vec::new();
618
619            for field in fields.iter() {
620                let field_name = field.ident.as_ref().unwrap();
621                let field_name_str = field_name.to_string();
622
623                // Check if field should be included for this target
624                let config = target.field_configs.get(&field_name_str);
625
626                // Skip if explicitly marked to skip
627                if let Some(cfg) = config
628                    && cfg.skip
629                {
630                    continue;
631                }
632
633                // For non-template targets, only include fields that are:
634                // 1. Explicitly marked for this target with #[prompt_for(name = "Target")]
635                // 2. Not marked for any specific target (default fields)
636                let is_explicitly_for_this_target = config.is_some_and(|c| c.include_only);
637                let has_any_target_specific_config = parse_prompt_for_attrs(&field.attrs)
638                    .iter()
639                    .any(|(name, _)| name != "*");
640
641                if has_any_target_specific_config && !is_explicitly_for_this_target {
642                    continue;
643                }
644
645                if let Some(cfg) = config {
646                    if cfg.image {
647                        image_field_parts.push(quote! {
648                            parts.extend(self.#field_name.to_prompt_parts());
649                        });
650                    } else {
651                        let key = cfg.rename.clone().unwrap_or_else(|| field_name_str.clone());
652
653                        let value_expr = if let Some(format_with) = &cfg.format_with {
654                            // Parse the function path - if it fails, generate code that will produce a compile error
655                            match syn::parse_str::<syn::Path>(format_with) {
656                                Ok(func_path) => quote! { #func_path(&self.#field_name) },
657                                Err(_) => {
658                                    // Generate a compile error by using an invalid identifier
659                                    let error_msg = format!(
660                                        "Invalid function path in format_with: '{}'",
661                                        format_with
662                                    );
663                                    quote! {
664                                        compile_error!(#error_msg);
665                                        String::new()
666                                    }
667                                }
668                            }
669                        } else {
670                            quote! { self.#field_name.to_prompt() }
671                        };
672
673                        text_field_parts.push(quote! {
674                            text_parts.push(format!("{}: {}", #key, #value_expr));
675                        });
676                    }
677                } else {
678                    // Default handling for fields without specific config
679                    text_field_parts.push(quote! {
680                        text_parts.push(format!("{}: {}", #field_name_str, self.#field_name.to_prompt()));
681                    });
682                }
683            }
684
685            match_arms.push(quote! {
686                #target_name => {
687                    let mut parts = Vec::new();
688
689                    #(#image_field_parts)*
690
691                    let mut text_parts = Vec::new();
692                    #(#text_field_parts)*
693
694                    if !text_parts.is_empty() {
695                        parts.push(llm_toolkit::prompt::PromptPart::Text(text_parts.join("\n")));
696                    }
697
698                    Ok(parts)
699                }
700            });
701        }
702    }
703
704    // Collect all target names for error reporting
705    let target_names: Vec<String> = targets.iter().map(|t| t.name.clone()).collect();
706
707    // Add default case for unknown targets
708    match_arms.push(quote! {
709        _ => {
710            let available = vec![#(#target_names.to_string()),*];
711            Err(llm_toolkit::prompt::PromptSetError::TargetNotFound {
712                target: target.to_string(),
713                available,
714            })
715        }
716    });
717
718    let struct_name = &input.ident;
719    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
720
721    let expanded = quote! {
722        impl #impl_generics llm_toolkit::prompt::ToPromptSet for #struct_name #ty_generics #where_clause {
723            fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<llm_toolkit::prompt::PromptPart>, llm_toolkit::prompt::PromptSetError> {
724                match target {
725                    #(#match_arms)*
726                }
727            }
728        }
729    };
730
731    TokenStream::from(expanded)
732}