cw_auth_protos/
lib.rs

1use quote::{ToTokens, quote};
2use proc_macro::TokenStream;
3use syn::{parse_macro_input, parse_quote, AttributeArgs, DataEnum, DeriveInput, Meta, MetaList, NestedMeta};
4
5
6
7
8fn strum_enum(input: &DeriveInput, attr_args: &[NestedMeta]) -> proc_macro2::TokenStream {
9    let ident = &input.ident;
10
11    // Extract optional name(...) argument
12    let name_arg = attr_args.iter().find_map(|meta| {
13        if let NestedMeta::Meta(Meta::List(MetaList { path, nested, .. })) = meta {
14            if path.is_ident("name") {
15                return Some(quote! { name(#nested) });
16            }
17        }
18        None
19    });
20
21    let maybe_name = if let Some(name) = name_arg {
22        quote! { #name, }
23    } else {
24        quote! {}
25    };
26
27
28    quote! {
29        #[derive(
30            ::std::fmt::Debug,
31            ::std::clone::Clone,
32            ::std::cmp::PartialEq,
33            ::saa_schema::strum_macros::Display,
34            ::saa_schema::strum_macros::EnumDiscriminants,
35            ::saa_schema::strum_macros::VariantNames,
36            ::saa_schema::serde::Serialize,
37            ::saa_schema::serde::Deserialize,
38            ::saa_schema::schemars::JsonSchema,
39        )]
40        #[strum_discriminants(
41            #maybe_name
42            derive(
43                ::saa_schema::serde::Serialize,
44                ::saa_schema::serde::Deserialize,
45                ::saa_schema::schemars::JsonSchema,
46                ::saa_schema::strum_macros::Display,
47                ::saa_schema::strum_macros::EnumString,
48                ::saa_schema::strum_macros::VariantArray,
49                ::saa_schema::strum_macros::AsRefStr
50            ),
51            serde(deny_unknown_fields, rename_all = "snake_case", crate = "::saa_schema::serde"),
52            strum(serialize_all = "snake_case", crate = "::saa_schema::strum"),
53            schemars(crate = "::saa_schema::schemars")
54        )]
55        #[strum(serialize_all = "snake_case", crate = "::saa_schema::strum")]
56        #[serde(deny_unknown_fields, rename_all = "snake_case", crate = "::saa_schema::serde")]
57        #[schemars(crate = "::saa_schema::schemars")]
58        #[allow(clippy::derive_partial_eq_without_eq)]
59        #input
60
61        impl ::saa_schema::strum::IntoDiscriminant for Box<#ident> {
62            type Discriminant = <#ident as ::saa_schema::strum::IntoDiscriminant>::Discriminant;
63            fn discriminant(&self) -> Self::Discriminant {
64                (*self).discriminant()
65            }
66        }
67
68    }
69}
70
71
72
73
74
75
76fn merge_enum_variants(
77    metadata: TokenStream,
78    left_ts: TokenStream,
79    right_ts: TokenStream,
80) -> TokenStream {
81    use syn::Data::Enum;
82
83    // Parse metadata and check no args
84    let args = parse_macro_input!(metadata as AttributeArgs);
85    if let Some(first_arg) = args.first() {
86        return syn::Error::new_spanned(first_arg, "macro takes no arguments")
87            .to_compile_error()
88            .into();
89    }
90
91    // Parse left and ensure it's enum
92    let mut left: DeriveInput = parse_macro_input!(left_ts);
93    let variants = match &mut left.data {
94        syn::Data::Enum(DataEnum { variants, .. }) => variants,
95        _ => return syn::Error::new(left.ident.span(), "only enums can accept variants")
96            .to_compile_error()
97            .into(),
98    };
99
100    // Parse right and ensure it's enum
101    let right: DeriveInput = parse_macro_input!(right_ts);
102    let Enum(DataEnum { variants: to_add, .. }) = right.data else {
103        return syn::Error::new(left.ident.span(), "only enums can provide variants")
104            .to_compile_error()
105            .into();
106    };
107
108    // Merge variants
109    variants.extend(to_add.into_iter());
110
111    // Return modified left
112    left.into_token_stream().into()
113}
114
115
116
117
118
119fn generate_session_macro<F>(
120    metadata: TokenStream,
121    input: TokenStream,
122    right_enum: TokenStream,
123    extra_impl: F,
124    extra_attrs: Option<Vec<syn::Attribute>>,
125) -> TokenStream
126where
127    F: Fn(&syn::Ident, &syn::Generics, &proc_macro2::TokenStream, &proc_macro2::TokenStream, Option<&syn::WhereClause>) -> proc_macro2::TokenStream,
128{
129    let merged = merge_enum_variants(metadata, input, right_enum);
130    // Try to parse the merged stream back into DeriveInput
131    let mut parsed = match syn::parse::<DeriveInput>(merged.clone()) {
132        Ok(val) => val,
133        Err(err) => return err.to_compile_error().into(),
134    };
135
136    
137    // If extra attributes were provided, extend them on the parsed item
138    if let Some(extra) = extra_attrs {
139        parsed.attrs.extend(extra);
140    }
141    
142    let enum_name = &parsed.ident;
143
144    let generics = &parsed.generics;
145    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
146
147    let common_impl = strum_enum(&parsed, &[]);
148
149    let custom_impl = extra_impl(
150        enum_name,
151        generics,
152        &quote! { #impl_generics },
153        &quote! { #ty_generics },
154        where_clause,
155    );
156
157    quote! {
158        #common_impl
159        #custom_impl
160    }
161    .into()
162}
163
164
165
166
167
168
169#[proc_macro_attribute]
170pub fn session_action(metadata: TokenStream, input: TokenStream) -> TokenStream {
171    generate_session_macro(
172        metadata,
173        input,
174        quote! {
175            enum SessionRight {
176                SessionActions(Box<::cw_auths::SessionActionMsg<Self>>),
177            }
178        }
179        .into(),
180        |enum_name, _generics, impl_generics, ty_generics, where_clause| {
181            quote! {
182                impl #impl_generics ::cw_auths::SessionActionsMatch for #enum_name #ty_generics #where_clause {
183                    fn match_actions(&self) -> Option<::cw_auths::SessionActionMsg<Self>> {
184                        match self {
185                            Self::SessionActions(msg) => Some((**msg).clone()),
186                            _ => None,
187                        }
188                    }
189                }
190            }
191        },
192        None,
193    )
194}
195
196
197
198
199
200#[proc_macro_attribute]
201pub fn session_query(metadata: TokenStream, input: TokenStream) -> TokenStream {
202    let args = parse_macro_input!(metadata as AttributeArgs);
203
204    // Ensure exactly one argument
205    if args.len() != 1 {
206        return syn::Error::new_spanned(
207            quote! { #[session_query(..)] },
208            "expected #[session_query(ExecuteMsg)] with exactly one argument",
209        )
210        .to_compile_error()
211        .into();
212    }
213
214    // Extract identifier (e.g., ExecuteMsg)
215    let base_msg_ident = match &args[0] {
216        syn::NestedMeta::Meta(syn::Meta::Path(path)) => match path.get_ident() {
217            Some(ident) => ident.clone(),
218            None => {
219                return syn::Error::new_spanned(
220                    path,
221                    "expected identifier like `ExecuteMsg`"
222                )
223                .to_compile_error()
224                .into();
225            }
226        },
227        other => {
228            return syn::Error::new_spanned(
229                other,
230                "expected identifier like `ExecuteMsg`"
231            )
232            .to_compile_error()
233            .into();
234        }
235    };
236
237
238    let extra_attrs = Some(vec![parse_quote! {
239        #[derive(::saa_schema::QueryResponses)]
240    }]);
241
242    // Proceed as before
243    generate_session_macro(
244        TokenStream::new(),
245        input,
246        quote! {
247            enum SessionRight {
248                #[returns(::cw_auths::QueryResTemplate)]
249                SessionQueries(Box<::cw_auths::SessionQueryMsg<Self>>),
250            }
251        }
252        .into(),
253        move |enum_name, _generics, impl_generics, ty_generics, where_clause| {
254            let base_msg = &base_msg_ident;
255            quote! {
256                impl #impl_generics ::cw_auths::SessionQueriesMatch for #enum_name #ty_generics #where_clause {
257                    fn match_queries(&self) -> Option<::cw_auths::SessionQueryMsg<Self>> {
258                        match self {
259                            Self::SessionQueries(msg) => Some((**msg).clone()),
260                            _ => None,
261                        }
262                    }
263                }
264                impl #impl_generics ::cw_auths::QueryUsesActions for #enum_name #ty_generics #where_clause {
265                    type ActionMsg = #base_msg;
266                }
267            }
268        },
269        extra_attrs,
270    )
271}
272
273
274