1use quote::{ToTokens, quote};
2use proc_macro::TokenStream;
3use syn::{parse_macro_input, parse_quote, AttributeArgs, DataEnum, DeriveInput, Meta, MetaList, NestedMeta};
4
5
6
7
8fn strum_enum(input: &DeriveInput, attr_args: &[NestedMeta]) -> proc_macro2::TokenStream {
9 let ident = &input.ident;
10
11 let name_arg = attr_args.iter().find_map(|meta| {
13 if let NestedMeta::Meta(Meta::List(MetaList { path, nested, .. })) = meta {
14 if path.is_ident("name") {
15 return Some(quote! { name(#nested) });
16 }
17 }
18 None
19 });
20
21 let maybe_name = if let Some(name) = name_arg {
22 quote! { #name, }
23 } else {
24 quote! {}
25 };
26
27
28 quote! {
29 #[derive(
30 ::std::fmt::Debug,
31 ::std::clone::Clone,
32 ::std::cmp::PartialEq,
33 ::saa_schema::strum_macros::Display,
34 ::saa_schema::strum_macros::EnumDiscriminants,
35 ::saa_schema::strum_macros::VariantNames,
36 ::saa_schema::serde::Serialize,
37 ::saa_schema::serde::Deserialize,
38 ::saa_schema::schemars::JsonSchema,
39 )]
40 #[strum_discriminants(
41 #maybe_name
42 derive(
43 ::saa_schema::serde::Serialize,
44 ::saa_schema::serde::Deserialize,
45 ::saa_schema::schemars::JsonSchema,
46 ::saa_schema::strum_macros::Display,
47 ::saa_schema::strum_macros::EnumString,
48 ::saa_schema::strum_macros::VariantArray,
49 ::saa_schema::strum_macros::AsRefStr
50 ),
51 serde(deny_unknown_fields, rename_all = "snake_case", crate = "::saa_schema::serde"),
52 strum(serialize_all = "snake_case", crate = "::saa_schema::strum"),
53 schemars(crate = "::saa_schema::schemars")
54 )]
55 #[strum(serialize_all = "snake_case", crate = "::saa_schema::strum")]
56 #[serde(deny_unknown_fields, rename_all = "snake_case", crate = "::saa_schema::serde")]
57 #[schemars(crate = "::saa_schema::schemars")]
58 #[allow(clippy::derive_partial_eq_without_eq)]
59 #input
60
61 impl ::saa_schema::strum::IntoDiscriminant for Box<#ident> {
62 type Discriminant = <#ident as ::saa_schema::strum::IntoDiscriminant>::Discriminant;
63 fn discriminant(&self) -> Self::Discriminant {
64 (*self).discriminant()
65 }
66 }
67
68 }
69}
70
71
72
73
74
75
76fn merge_enum_variants(
77 metadata: TokenStream,
78 left_ts: TokenStream,
79 right_ts: TokenStream,
80) -> TokenStream {
81 use syn::Data::Enum;
82
83 let args = parse_macro_input!(metadata as AttributeArgs);
85 if let Some(first_arg) = args.first() {
86 return syn::Error::new_spanned(first_arg, "macro takes no arguments")
87 .to_compile_error()
88 .into();
89 }
90
91 let mut left: DeriveInput = parse_macro_input!(left_ts);
93 let variants = match &mut left.data {
94 syn::Data::Enum(DataEnum { variants, .. }) => variants,
95 _ => return syn::Error::new(left.ident.span(), "only enums can accept variants")
96 .to_compile_error()
97 .into(),
98 };
99
100 let right: DeriveInput = parse_macro_input!(right_ts);
102 let Enum(DataEnum { variants: to_add, .. }) = right.data else {
103 return syn::Error::new(left.ident.span(), "only enums can provide variants")
104 .to_compile_error()
105 .into();
106 };
107
108 variants.extend(to_add.into_iter());
110
111 left.into_token_stream().into()
113}
114
115
116
117
118
119fn generate_session_macro<F>(
120 metadata: TokenStream,
121 input: TokenStream,
122 right_enum: TokenStream,
123 extra_impl: F,
124 extra_attrs: Option<Vec<syn::Attribute>>,
125) -> TokenStream
126where
127 F: Fn(&syn::Ident, &syn::Generics, &proc_macro2::TokenStream, &proc_macro2::TokenStream, Option<&syn::WhereClause>) -> proc_macro2::TokenStream,
128{
129 let merged = merge_enum_variants(metadata, input, right_enum);
130 let mut parsed = match syn::parse::<DeriveInput>(merged.clone()) {
132 Ok(val) => val,
133 Err(err) => return err.to_compile_error().into(),
134 };
135
136
137 if let Some(extra) = extra_attrs {
139 parsed.attrs.extend(extra);
140 }
141
142 let enum_name = &parsed.ident;
143
144 let generics = &parsed.generics;
145 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
146
147 let common_impl = strum_enum(&parsed, &[]);
148
149 let custom_impl = extra_impl(
150 enum_name,
151 generics,
152 "e! { #impl_generics },
153 "e! { #ty_generics },
154 where_clause,
155 );
156
157 quote! {
158 #common_impl
159 #custom_impl
160 }
161 .into()
162}
163
164
165
166
167
168
169#[proc_macro_attribute]
170pub fn session_action(metadata: TokenStream, input: TokenStream) -> TokenStream {
171 generate_session_macro(
172 metadata,
173 input,
174 quote! {
175 enum SessionRight {
176 SessionActions(Box<::cw_auth::SessionActionMsg<Self>>),
177 }
178 }
179 .into(),
180 |enum_name, _generics, impl_generics, ty_generics, where_clause| {
181 quote! {
182 impl #impl_generics ::cw_auth::SessionActionsMatch for #enum_name #ty_generics #where_clause {
183 fn match_actions(&self) -> Option<::cw_auth::SessionActionMsg<Self>> {
184 match self {
185 Self::SessionActions(msg) => Some((**msg).clone()),
186 _ => None,
187 }
188 }
189 }
190 }
191 },
192 None,
193 )
194}
195
196
197
198
199
200#[proc_macro_attribute]
201pub fn session_query(metadata: TokenStream, input: TokenStream) -> TokenStream {
202 let args = parse_macro_input!(metadata as AttributeArgs);
203
204 if args.len() != 1 {
206 return syn::Error::new_spanned(
207 quote! { #[session_query(..)] },
208 "expected #[session_query(ExecuteMsg)] with exactly one argument",
209 )
210 .to_compile_error()
211 .into();
212 }
213
214 let base_msg_ident = match &args[0] {
216 syn::NestedMeta::Meta(syn::Meta::Path(path)) => match path.get_ident() {
217 Some(ident) => ident.clone(),
218 None => {
219 return syn::Error::new_spanned(
220 path,
221 "expected identifier like `ExecuteMsg`"
222 )
223 .to_compile_error()
224 .into();
225 }
226 },
227 other => {
228 return syn::Error::new_spanned(
229 other,
230 "expected identifier like `ExecuteMsg`"
231 )
232 .to_compile_error()
233 .into();
234 }
235 };
236
237
238 let extra_attrs = Some(vec![parse_quote! {
239 #[derive(::saa_schema::QueryResponses)]
240 }]);
241
242 generate_session_macro(
244 TokenStream::new(),
245 input,
246 quote! {
247 enum SessionRight {
248 #[returns(::cw_auth::QueryResTemplate)]
249 SessionQueries(Box<::cw_auth::SessionQueryMsg<Self>>),
250 }
251 }
252 .into(),
253 move |enum_name, _generics, impl_generics, ty_generics, where_clause| {
254 let base_msg = &base_msg_ident;
255 quote! {
256 impl #impl_generics ::cw_auth::SessionQueriesMatch for #enum_name #ty_generics #where_clause {
257 fn match_queries(&self) -> Option<::cw_auth::SessionQueryMsg<Self>> {
258 match self {
259 Self::SessionQueries(msg) => Some((**msg).clone()),
260 _ => None,
261 }
262 }
263 }
264 impl #impl_generics ::cw_auth::QueryUsesActions for #enum_name #ty_generics #where_clause {
265 type ActionMsg = #base_msg;
266 }
267 }
268 },
269 extra_attrs,
270 )
271}
272
273
274