adaptivemsg_macros/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::parse::Parser;
5use syn::{parse_macro_input, Fields, ItemImpl, ItemStruct, LitStr};
6
7fn compile_error<T: quote::ToTokens>(tokens: T, message: &str) -> TokenStream {
8 syn::Error::new_spanned(tokens, message)
9 .to_compile_error()
10 .into()
11}
12
13#[proc_macro_attribute]
14pub fn message_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
15 let input = parse_macro_input!(item as ItemImpl);
16 let Some((_, trait_path, _)) = input.trait_.as_ref() else {
17 return compile_error(&input.self_ty, "message_handler must be used on an impl of MessageHandler");
18 };
19 let is_message_handler = trait_path
20 .segments
21 .last()
22 .map(|seg| seg.ident == "MessageHandler")
23 .unwrap_or(false);
24 if !is_message_handler {
25 return compile_error(trait_path, "message_handler must be used on an impl of MessageHandler");
26 }
27 if !input.generics.params.is_empty() {
28 return compile_error(&input.generics, "message_handler does not support generic impls");
29 }
30 let ty = *input.self_ty.clone();
31 let expanded = quote! {
32 #[::adaptivemsg::async_trait]
33 #input
34 ::adaptivemsg::submit_message_handler!(#ty);
35 ::adaptivemsg::submit_message!(#ty);
36 };
37 TokenStream::from(expanded)
38}
39
40#[proc_macro_attribute]
41pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream {
42 let mut ns: Option<LitStr> = None;
43 let mut base_name: Option<LitStr> = None;
44 let mut register: bool = false;
45 let parser = syn::meta::parser(|meta| {
46 if meta.path.is_ident("ns") {
47 let lit: LitStr = meta.value()?.parse()?;
48 ns = Some(lit);
49 return Ok(());
50 }
51 if meta.path.is_ident("name") {
52 let lit: LitStr = meta.value()?.parse()?;
53 base_name = Some(lit);
54 return Ok(());
55 }
56 if meta.path.is_ident("register") {
57 register = true;
58 return Ok(());
59 }
60 Err(meta.error("unsupported message attribute; use ns=\"...\", name=\"...\", or register"))
61 });
62 if let Err(err) = parser.parse(attr.into()) {
63 return err.to_compile_error().into();
64 }
65
66 let input = parse_macro_input!(item as ItemStruct);
67 let name = &input.ident;
68 if !input.generics.params.is_empty() {
69 return compile_error(&input.generics, "message does not support generic structs");
70 }
71 let fields = match &input.fields {
72 Fields::Named(fields) => fields,
73 _ => {
74 return compile_error(
75 &input.ident,
76 "message only supports structs with named fields",
77 )
78 }
79 };
80 let field_count = fields.named.len();
81 let encode_fields = fields.named.iter().map(|field| {
82 let ident = field.ident.as_ref().unwrap();
83 quote! {
84 items.push(::adaptivemsg::__private::rmpv::ext::to_value(&self.#ident)?);
85 }
86 });
87 let decode_fields = fields.named.iter().map(|field| {
88 let ident = field.ident.as_ref().unwrap();
89 let ty = &field.ty;
90 quote! {
91 let #ident: #ty = ::adaptivemsg::__private::rmpv::ext::from_value(iter.next().unwrap())?;
92 }
93 });
94 let init_fields = fields.named.iter().map(|field| {
95 let ident = field.ident.as_ref().unwrap();
96 quote! { #ident }
97 });
98 let ns_lit = ns.unwrap_or_else(|| LitStr::new("am", Span::call_site()));
99 let base_expr = if let Some(base_name) = base_name {
100 quote! { #base_name.to_string() }
101 } else {
102 quote! {{
103 let module_leaf = ::core::module_path!()
104 .rsplit("::")
105 .next()
106 .unwrap_or("unknown");
107 format!("{}.{}", module_leaf, stringify!(#name))
108 }}
109 };
110 let register_submit = if register {
111 quote! { ::adaptivemsg::submit_message!(#name); }
112 } else {
113 quote! {}
114 };
115 let expanded = quote! {
116 #[derive(::serde::Serialize, ::serde::Deserialize)]
117 #input
118 impl ::adaptivemsg::Message for #name {
119 fn wire_name(&self) -> &'static str {
120 Self::wire_name_static()
121 }
122
123 fn wire_name_static() -> &'static str {
124 static WIRE_NAME: ::std::sync::OnceLock<String> = ::std::sync::OnceLock::new();
125 WIRE_NAME.get_or_init(|| {
126 let ns = #ns_lit;
127 let base = #base_expr;
128 format!("{ns}.{base}")
129 }).as_str()
130 }
131
132 fn encode_map(&self) -> ::std::result::Result<Vec<u8>, ::adaptivemsg::Error> {
133 #[derive(::serde::Serialize)]
134 struct Envelope<'a, T: ::serde::Serialize> {
135 r#type: &'a str,
136 data: &'a T,
137 }
138 let env = Envelope {
139 r#type: Self::wire_name_static(),
140 data: self,
141 };
142 ::adaptivemsg::__private::rmp_serde::to_vec_named(&env).map_err(::adaptivemsg::Error::from)
143 }
144
145 fn encode_compact(&self) -> ::std::result::Result<Vec<u8>, ::adaptivemsg::Error> {
146 let mut items = Vec::with_capacity(1 + #field_count);
147 items.push(::adaptivemsg::__private::rmpv::Value::String(::adaptivemsg::__private::rmpv::Utf8String::from(Self::wire_name_static())));
148 #(#encode_fields)*
149 let value = ::adaptivemsg::__private::rmpv::Value::Array(items);
150 let mut buf = Vec::new();
151 ::adaptivemsg::__private::rmpv::encode::write_value(&mut buf, &value)?;
152 Ok(buf)
153 }
154
155 fn encode_postcard(&self) -> ::std::result::Result<Vec<u8>, ::adaptivemsg::Error> {
156 ::adaptivemsg::__private::postcard::to_stdvec(self).map_err(::adaptivemsg::Error::from)
157 }
158
159 fn as_any(&self) -> &dyn ::core::any::Any {
160 self
161 }
162 }
163
164 impl ::adaptivemsg::__private::MessageDecode for #name {
165 fn decode_map(value: ::adaptivemsg::__private::rmpv::Value) -> ::std::result::Result<Self, ::adaptivemsg::Error> {
166 ::adaptivemsg::__private::rmpv::ext::from_value(value).map_err(::adaptivemsg::Error::from)
167 }
168
169 fn decode_compact(values: Vec<::adaptivemsg::__private::rmpv::Value>) -> ::std::result::Result<Self, ::adaptivemsg::Error> {
170 if values.len() != #field_count {
171 return Err(::adaptivemsg::Error::CompactFieldCount {
172 expected: #field_count,
173 got: values.len(),
174 });
175 }
176 let mut iter = values.into_iter();
177 #(#decode_fields)*
178 Ok(Self { #(#init_fields),* })
179 }
180
181 fn decode_postcard(payload: &[u8]) -> ::std::result::Result<Self, ::adaptivemsg::Error> {
182 ::adaptivemsg::__private::postcard::from_bytes(payload).map_err(::adaptivemsg::Error::from)
183 }
184 }
185 #register_submit
186 };
187 TokenStream::from(expanded)
188}