batch_mode_token_expansion_axis_derive/
lib.rs

1// ---------------- [ File: batch-mode-token-expansion-axis-derive/src/lib.rs ]
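//! The `batch-mode-token-expansion-axis-derive` crate defines the custom
//! `#[derive(TokenExpansionAxis)]` procedural macro. It reads the enum, extracts the
//! optional `#[system_message_goal("...")]` attribute and each variant's
//! `#[axis("axis_name" => "axis_description")]` attribute, and implements the
//! `AxisName` and `AxisDescription` traits on the enum. It also generates an
//! `Expanded…` data struct and an `…Expander` aggregator that implements
//! `SystemMessageGoal`; the generated code relies on the traits exported by the
//! `batch_mode_token_expansion_traits` crate.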
7#![forbid(unsafe_code)]
8#![deny(clippy::unwrap_used)]
9#![deny(clippy::expect_used)]
10//#![deny(missing_docs)]
11
12extern crate proc_macro;
13
14#[macro_use] mod imports; use imports::*;
15
16xp!{system_message_goal}
17xp!{axis_attribute}
18xp!{try_parse_name_value}
19xp!{try_parse_parenthesized}
20
21#[proc_macro_derive(TokenExpansionAxis, attributes(system_message_goal, axis))]
22pub fn derive_token_expander_axis(input: TokenStream) -> TokenStream {
23    use syn::spanned::Spanned;
24
25    trace!("Parsing input for TokenExpansionAxis derive macro.");
26    let ast = parse_macro_input!(input as DeriveInput);
27
28    let enum_ident = &ast.ident;
29    let enum_name_str = enum_ident.to_string();
30    debug!("Found enum name: {}", enum_name_str);
31
32    // We’ll strip “ExpanderAxis” from the end of the enum name and
33    // create the expanded struct + aggregator with the usual naming strategy.
34    let expanded_struct_name_str = if let Some(stripped) = enum_name_str.strip_suffix("ExpanderAxis") {
35        format!("Expanded{}", stripped)
36    } else {
37        format!("Expanded{}", enum_name_str)
38    };
39    let aggregator_name_str = if let Some(stripped) = enum_name_str.strip_suffix("ExpanderAxis") {
40        format!("{}Expander", stripped)
41    } else {
42        format!("{}Expander", enum_name_str)
43    };
44
45    let expanded_struct_ident = syn::Ident::new(&expanded_struct_name_str, enum_ident.span());
46    let aggregator_ident = syn::Ident::new(&aggregator_name_str, enum_ident.span());
47
48    // Ensure this derive only applies to enums
49    let data_enum = match &ast.data {
50        syn::Data::Enum(e) => {
51            trace!("Confirmed that {} is an enum.", enum_name_str);
52            e
53        }
54        _ => {
55            let err = syn::Error::new_spanned(enum_ident, "This macro only supports enums.");
56            error!("Encountered a non-enum type for #[derive(TokenExpansionAxis)]: {}", err);
57            return err.to_compile_error().into();
58        }
59    };
60
61    // Attempt to parse any system_message_goal attribute on the enum
62    let system_message_goal = match parse_system_message_goal(&ast.attrs) {
63        Ok(Some(lit_str)) => {
64            debug!("system_message_goal attribute found: {:?}", lit_str.value());
65            lit_str
66        }
67        Ok(None) => {
68            debug!("No system_message_goal found; using default fallback.");
69            syn::LitStr::new("Default system message goal", enum_ident.span())
70        }
71        Err(err) => {
72            error!("Error parsing system_message_goal: {}", err);
73            return err.to_compile_error().into();
74        }
75    };
76
77    // Gather data for implementing AxisName/AxisDescription,
78    // plus building the aggregator’s axes() and the struct fields.
79    let mut axis_name_matches = Vec::new();
80    let mut axis_desc_matches = Vec::new();
81    let mut variant_idents    = Vec::new();
82    let mut struct_fields     = Vec::new();
83
84    for variant in &data_enum.variants {
85        let variant_ident = &variant.ident;
86        trace!("Processing variant: {}", variant_ident);
87
88        variant_idents.push(variant_ident);
89
90        // We expect each variant to have exactly one #[axis("... => ...")]
91        let (axis_name, axis_desc) = variant
92            .attrs
93            .iter()
94            .filter_map(|attr| parse_axis_attribute(attr).ok().flatten())
95            .next()
96            .unwrap_or_else(|| {
97                panic!(
98                    "Missing #[axis(\"axis_name => axis_description\")] on variant '{}'",
99                    variant_ident
100                );
101            });
102
103        // For AxisName / AxisDescription matching:
104        axis_name_matches.push(quote! {
105            #enum_ident::#variant_ident => std::borrow::Cow::Borrowed(#axis_name)
106        });
107        axis_desc_matches.push(quote! {
108            #enum_ident::#variant_ident => std::borrow::Cow::Borrowed(#axis_desc)
109        });
110
111        // Generate a field in the expanded struct (snake_case from the axis name)
112        let field_name_str = axis_name.to_snake_case();
113        let field_ident = syn::Ident::new(&field_name_str, variant_ident.span());
114        let field_ty: syn::Type = syn::parse_str("Option<Vec<String>>").unwrap();
115
116        let field = quote! {
117            #[serde(default)]
118            #field_ident : #field_ty
119        };
120        struct_fields.push(field);
121    }
122
123    // 1) Implement the axis traits on the original enum
124    let expanded_impl_axis_traits = quote! {
125        impl TokenExpansionAxis for #enum_ident {}
126
127        impl AxisName for #enum_ident {
128            fn axis_name(&self) -> std::borrow::Cow<'_, str> {
129                match self {
130                    #( #axis_name_matches ),*
131                }
132            }
133        }
134
135        impl AxisDescription for #enum_ident {
136            fn axis_description(&self) -> std::borrow::Cow<'_, str> {
137                match self {
138                    #( #axis_desc_matches ),*
139                }
140            }
141        }
142    };
143
144    // 2) The data-carrying struct: Expanded{Enum}
145    let expanded_struct = quote! {
146        #[derive(NamedItem,SaveLoad,AiJsonTemplate,Getters, Debug, Clone, Serialize, Deserialize)]
147        #[getset(get="pub")]
148        pub struct #expanded_struct_ident {
149            #[serde(alias = "token_name")]
150            name: String,
151            #( #struct_fields ),*
152        }
153
154        unsafe impl Send for #expanded_struct_ident {}
155        unsafe impl Sync for #expanded_struct_ident {}
156
157        #[async_trait::async_trait]
158        impl ExpandedToken for #expanded_struct_ident {
159            type Expander = #aggregator_ident;
160        }
161    };
162
163    /*
164    // 2a) Implement LoadFromFile if needed
165    let load_from_file_impl = quote! {
166        #[async_trait::async_trait]
167        impl LoadFromFile for #expanded_struct_ident {
168            type Error = SaveLoadError;
169
170            async fn load_from_file(filename: impl AsRef<std::path::Path> + Send)
171                -> Result<Self, Self::Error>
172            {
173                info!("loading token expansion from file {:?}", filename.as_ref());
174                let json = std::fs::read_to_string(filename)?;
175                let this = serde_json::from_str(&json)
176                    .map_err(|e| SaveLoadError::JsonParseError(e.into()))?;
177                Ok(this)
178            }
179        }
180    };
181    */
182
183    // 3) The aggregator struct: e.g. EnumExpander
184    let aggregator_struct = quote! {
185        #[derive(Debug, Clone)]
186        pub struct #aggregator_ident;
187
188        impl Default for #aggregator_ident {
189            fn default() -> Self {
190                Self
191            }
192        }
193    };
194
195    // 3a) Named for aggregator => return the aggregator’s own name
196    let named_impl = quote! {
197        impl Named for #aggregator_ident {
198            fn name(&self) -> std::borrow::Cow<'_, str> {
199                // IMPORTANT: We actually want the aggregator name here,
200                // which is aggregator_name_str, not the enum's name.
201                std::borrow::Cow::Borrowed(#aggregator_name_str)
202            }
203        }
204    };
205
206    // 3b) SystemMessageGoal => use the parsed or default system_message_goal
207    let system_message_goal_impl = quote! {
208        impl SystemMessageGoal for #aggregator_ident {
209            fn system_message_goal(&self) -> std::borrow::Cow<'_, str> {
210                std::borrow::Cow::Borrowed(#system_message_goal)
211            }
212        }
213    };
214
215    // 3c) TokenExpander + GetTokenExpansionAxes => aggregator’s axes
216    let variant_arms = variant_idents.iter().map(|v| {
217        quote! { std::sync::Arc::new(#enum_ident::#v) }
218    });
219
220    let token_expander_impl = quote! {
221        impl TokenExpander for #aggregator_ident {}
222
223        impl GetTokenExpansionAxes for #aggregator_ident {
224            fn axes(&self) -> TokenExpansionAxes {
225                vec![
226                    #( #variant_arms ),*
227                ]
228            }
229        }
230    };
231
232    // Combine everything
233    let expanded = quote! {
234        use batch_mode_token_expansion_traits::*;
235
236        #expanded_impl_axis_traits
237        #expanded_struct
238        //#load_from_file_impl
239
240        #aggregator_struct
241        #named_impl
242        #system_message_goal_impl
243        #token_expander_impl
244    };
245
246    trace!("Macro expansion for {} completed.", enum_name_str);
247    TokenStream::from(expanded)
248}