batch_mode_token_expansion_axis_derive/
lib.rs1#![forbid(unsafe_code)]
8#![deny(clippy::unwrap_used)]
9#![deny(clippy::expect_used)]
10extern crate proc_macro;
13
14#[macro_use] mod imports; use imports::*;
15
16xp!{system_message_goal}
17xp!{axis_attribute}
18xp!{try_parse_name_value}
19xp!{try_parse_parenthesized}
20
21#[proc_macro_derive(TokenExpansionAxis, attributes(system_message_goal, axis))]
22pub fn derive_token_expander_axis(input: TokenStream) -> TokenStream {
23 use syn::spanned::Spanned;
24
25 trace!("Parsing input for TokenExpansionAxis derive macro.");
26 let ast = parse_macro_input!(input as DeriveInput);
27
28 let enum_ident = &ast.ident;
29 let enum_name_str = enum_ident.to_string();
30 debug!("Found enum name: {}", enum_name_str);
31
32 let expanded_struct_name_str = if let Some(stripped) = enum_name_str.strip_suffix("ExpanderAxis") {
35 format!("Expanded{}", stripped)
36 } else {
37 format!("Expanded{}", enum_name_str)
38 };
39 let aggregator_name_str = if let Some(stripped) = enum_name_str.strip_suffix("ExpanderAxis") {
40 format!("{}Expander", stripped)
41 } else {
42 format!("{}Expander", enum_name_str)
43 };
44
45 let expanded_struct_ident = syn::Ident::new(&expanded_struct_name_str, enum_ident.span());
46 let aggregator_ident = syn::Ident::new(&aggregator_name_str, enum_ident.span());
47
48 let data_enum = match &ast.data {
50 syn::Data::Enum(e) => {
51 trace!("Confirmed that {} is an enum.", enum_name_str);
52 e
53 }
54 _ => {
55 let err = syn::Error::new_spanned(enum_ident, "This macro only supports enums.");
56 error!("Encountered a non-enum type for #[derive(TokenExpansionAxis)]: {}", err);
57 return err.to_compile_error().into();
58 }
59 };
60
61 let system_message_goal = match parse_system_message_goal(&ast.attrs) {
63 Ok(Some(lit_str)) => {
64 debug!("system_message_goal attribute found: {:?}", lit_str.value());
65 lit_str
66 }
67 Ok(None) => {
68 debug!("No system_message_goal found; using default fallback.");
69 syn::LitStr::new("Default system message goal", enum_ident.span())
70 }
71 Err(err) => {
72 error!("Error parsing system_message_goal: {}", err);
73 return err.to_compile_error().into();
74 }
75 };
76
77 let mut axis_name_matches = Vec::new();
80 let mut axis_desc_matches = Vec::new();
81 let mut variant_idents = Vec::new();
82 let mut struct_fields = Vec::new();
83
84 for variant in &data_enum.variants {
85 let variant_ident = &variant.ident;
86 trace!("Processing variant: {}", variant_ident);
87
88 variant_idents.push(variant_ident);
89
90 let (axis_name, axis_desc) = variant
92 .attrs
93 .iter()
94 .filter_map(|attr| parse_axis_attribute(attr).ok().flatten())
95 .next()
96 .unwrap_or_else(|| {
97 panic!(
98 "Missing #[axis(\"axis_name => axis_description\")] on variant '{}'",
99 variant_ident
100 );
101 });
102
103 axis_name_matches.push(quote! {
105 #enum_ident::#variant_ident => std::borrow::Cow::Borrowed(#axis_name)
106 });
107 axis_desc_matches.push(quote! {
108 #enum_ident::#variant_ident => std::borrow::Cow::Borrowed(#axis_desc)
109 });
110
111 let field_name_str = axis_name.to_snake_case();
113 let field_ident = syn::Ident::new(&field_name_str, variant_ident.span());
114 let field_ty: syn::Type = syn::parse_str("Option<Vec<String>>").unwrap();
115
116 let field = quote! {
117 #[serde(default)]
118 #field_ident : #field_ty
119 };
120 struct_fields.push(field);
121 }
122
123 let expanded_impl_axis_traits = quote! {
125 impl TokenExpansionAxis for #enum_ident {}
126
127 impl AxisName for #enum_ident {
128 fn axis_name(&self) -> std::borrow::Cow<'_, str> {
129 match self {
130 #( #axis_name_matches ),*
131 }
132 }
133 }
134
135 impl AxisDescription for #enum_ident {
136 fn axis_description(&self) -> std::borrow::Cow<'_, str> {
137 match self {
138 #( #axis_desc_matches ),*
139 }
140 }
141 }
142 };
143
144 let expanded_struct = quote! {
146 #[derive(NamedItem,SaveLoad,AiJsonTemplate,Getters, Debug, Clone, Serialize, Deserialize)]
147 #[getset(get="pub")]
148 pub struct #expanded_struct_ident {
149 #[serde(alias = "token_name")]
150 name: String,
151 #( #struct_fields ),*
152 }
153
154 unsafe impl Send for #expanded_struct_ident {}
155 unsafe impl Sync for #expanded_struct_ident {}
156
157 #[async_trait::async_trait]
158 impl ExpandedToken for #expanded_struct_ident {
159 type Expander = #aggregator_ident;
160 }
161 };
162
163 let aggregator_struct = quote! {
185 #[derive(Debug, Clone)]
186 pub struct #aggregator_ident;
187
188 impl Default for #aggregator_ident {
189 fn default() -> Self {
190 Self
191 }
192 }
193 };
194
195 let named_impl = quote! {
197 impl Named for #aggregator_ident {
198 fn name(&self) -> std::borrow::Cow<'_, str> {
199 std::borrow::Cow::Borrowed(#aggregator_name_str)
202 }
203 }
204 };
205
206 let system_message_goal_impl = quote! {
208 impl SystemMessageGoal for #aggregator_ident {
209 fn system_message_goal(&self) -> std::borrow::Cow<'_, str> {
210 std::borrow::Cow::Borrowed(#system_message_goal)
211 }
212 }
213 };
214
215 let variant_arms = variant_idents.iter().map(|v| {
217 quote! { std::sync::Arc::new(#enum_ident::#v) }
218 });
219
220 let token_expander_impl = quote! {
221 impl TokenExpander for #aggregator_ident {}
222
223 impl GetTokenExpansionAxes for #aggregator_ident {
224 fn axes(&self) -> TokenExpansionAxes {
225 vec![
226 #( #variant_arms ),*
227 ]
228 }
229 }
230 };
231
232 let expanded = quote! {
234 use batch_mode_token_expansion_traits::*;
235
236 #expanded_impl_axis_traits
237 #expanded_struct
238 #aggregator_struct
241 #named_impl
242 #system_message_goal_impl
243 #token_expander_impl
244 };
245
246 trace!("Macro expansion for {} completed.", enum_name_str);
247 TokenStream::from(expanded)
248}