Skip to main content

adaptivemsg_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::parse::Parser;
5use syn::{parse_macro_input, Fields, ItemImpl, ItemStruct, LitStr};
6
7fn compile_error<T: quote::ToTokens>(tokens: T, message: &str) -> TokenStream {
8    syn::Error::new_spanned(tokens, message)
9        .to_compile_error()
10        .into()
11}
12
13#[proc_macro_attribute]
14pub fn message_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
15    let input = parse_macro_input!(item as ItemImpl);
16    let Some((_, trait_path, _)) = input.trait_.as_ref() else {
17        return compile_error(&input.self_ty, "message_handler must be used on an impl of MessageHandler");
18    };
19    let is_message_handler = trait_path
20        .segments
21        .last()
22        .map(|seg| seg.ident == "MessageHandler")
23        .unwrap_or(false);
24    if !is_message_handler {
25        return compile_error(trait_path, "message_handler must be used on an impl of MessageHandler");
26    }
27    if !input.generics.params.is_empty() {
28        return compile_error(&input.generics, "message_handler does not support generic impls");
29    }
30    let ty = *input.self_ty.clone();
31    let expanded = quote! {
32        #[::adaptivemsg::async_trait]
33        #input
34        ::adaptivemsg::submit_message_handler!(#ty);
35        ::adaptivemsg::submit_message!(#ty);
36    };
37    TokenStream::from(expanded)
38}
39
40#[proc_macro_attribute]
41pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream {
42    let mut ns: Option<LitStr> = None;
43    let mut base_name: Option<LitStr> = None;
44    let mut register: bool = false;
45    let parser = syn::meta::parser(|meta| {
46        if meta.path.is_ident("ns") {
47            let lit: LitStr = meta.value()?.parse()?;
48            ns = Some(lit);
49            return Ok(());
50        }
51        if meta.path.is_ident("name") {
52            let lit: LitStr = meta.value()?.parse()?;
53            base_name = Some(lit);
54            return Ok(());
55        }
56        if meta.path.is_ident("register") {
57            register = true;
58            return Ok(());
59        }
60        Err(meta.error("unsupported message attribute; use ns=\"...\", name=\"...\", or register"))
61    });
62    if let Err(err) = parser.parse(attr.into()) {
63        return err.to_compile_error().into();
64    }
65
66    let input = parse_macro_input!(item as ItemStruct);
67    let name = &input.ident;
68    if !input.generics.params.is_empty() {
69        return compile_error(&input.generics, "message does not support generic structs");
70    }
71    let fields = match &input.fields {
72        Fields::Named(fields) => fields,
73        _ => {
74            return compile_error(
75                &input.ident,
76                "message only supports structs with named fields",
77            )
78        }
79    };
80    let field_count = fields.named.len();
81    let encode_fields = fields.named.iter().map(|field| {
82        let ident = field.ident.as_ref().unwrap();
83        quote! {
84            items.push(::adaptivemsg::__private::rmpv::ext::to_value(&self.#ident)?);
85        }
86    });
87    let decode_fields = fields.named.iter().map(|field| {
88        let ident = field.ident.as_ref().unwrap();
89        let ty = &field.ty;
90        quote! {
91            let #ident: #ty = ::adaptivemsg::__private::rmpv::ext::from_value(iter.next().unwrap())?;
92        }
93    });
94    let init_fields = fields.named.iter().map(|field| {
95        let ident = field.ident.as_ref().unwrap();
96        quote! { #ident }
97    });
98    let ns_lit = ns.unwrap_or_else(|| LitStr::new("am", Span::call_site()));
99    let base_expr = if let Some(base_name) = base_name {
100        quote! { #base_name.to_string() }
101    } else {
102        quote! {{
103            let module_leaf = ::core::module_path!()
104                .rsplit("::")
105                .next()
106                .unwrap_or("unknown");
107            format!("{}.{}", module_leaf, stringify!(#name))
108        }}
109    };
110    let register_submit = if register {
111        quote! { ::adaptivemsg::submit_message!(#name); }
112    } else {
113        quote! {}
114    };
115    let expanded = quote! {
116        #[derive(::serde::Serialize, ::serde::Deserialize)]
117        #input
118        impl ::adaptivemsg::Message for #name {
119            fn wire_name(&self) -> &'static str {
120                Self::wire_name_static()
121            }
122
123            fn wire_name_static() -> &'static str {
124                static WIRE_NAME: ::std::sync::OnceLock<String> = ::std::sync::OnceLock::new();
125                WIRE_NAME.get_or_init(|| {
126                    let ns = #ns_lit;
127                    let base = #base_expr;
128                    format!("{ns}.{base}")
129                }).as_str()
130            }
131
132            fn encode_map(&self) -> ::std::result::Result<Vec<u8>, ::adaptivemsg::Error> {
133                #[derive(::serde::Serialize)]
134                struct Envelope<'a, T: ::serde::Serialize> {
135                    r#type: &'a str,
136                    data: &'a T,
137                }
138                let env = Envelope {
139                    r#type: Self::wire_name_static(),
140                    data: self,
141                };
142                ::adaptivemsg::__private::rmp_serde::to_vec_named(&env).map_err(::adaptivemsg::Error::from)
143            }
144
145            fn encode_compact(&self) -> ::std::result::Result<Vec<u8>, ::adaptivemsg::Error> {
146                let mut items = Vec::with_capacity(1 + #field_count);
147                items.push(::adaptivemsg::__private::rmpv::Value::String(::adaptivemsg::__private::rmpv::Utf8String::from(Self::wire_name_static())));
148                #(#encode_fields)*
149                let value = ::adaptivemsg::__private::rmpv::Value::Array(items);
150                let mut buf = Vec::new();
151                ::adaptivemsg::__private::rmpv::encode::write_value(&mut buf, &value)?;
152                Ok(buf)
153            }
154
155            fn encode_postcard(&self) -> ::std::result::Result<Vec<u8>, ::adaptivemsg::Error> {
156                ::adaptivemsg::__private::postcard::to_stdvec(self).map_err(::adaptivemsg::Error::from)
157            }
158
159            fn as_any(&self) -> &dyn ::core::any::Any {
160                self
161            }
162        }
163
164        impl ::adaptivemsg::__private::MessageDecode for #name {
165            fn decode_map(value: ::adaptivemsg::__private::rmpv::Value) -> ::std::result::Result<Self, ::adaptivemsg::Error> {
166                ::adaptivemsg::__private::rmpv::ext::from_value(value).map_err(::adaptivemsg::Error::from)
167            }
168
169            fn decode_compact(values: Vec<::adaptivemsg::__private::rmpv::Value>) -> ::std::result::Result<Self, ::adaptivemsg::Error> {
170                if values.len() != #field_count {
171                    return Err(::adaptivemsg::Error::CompactFieldCount {
172                        expected: #field_count,
173                        got: values.len(),
174                    });
175                }
176                let mut iter = values.into_iter();
177                #(#decode_fields)*
178                Ok(Self { #(#init_fields),* })
179            }
180
181            fn decode_postcard(payload: &[u8]) -> ::std::result::Result<Self, ::adaptivemsg::Error> {
182                ::adaptivemsg::__private::postcard::from_bytes(payload).map_err(::adaptivemsg::Error::from)
183            }
184        }
185        #register_submit
186    };
187    TokenStream::from(expanded)
188}