elfo_macros_impl/
message.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::{
5    parenthesized,
6    parse::{Error as ParseError, Parse, ParseStream},
7    parse_macro_input, Data, DeriveInput, Ident, LitStr, Path, Token, Type,
8};
9
10#[derive(Debug)]
11struct MessageArgs {
12    name: Option<LitStr>,
13    ret: Option<Type>,
14    part: bool,
15    transparent: bool,
16    dumping_allowed: bool,
17    crate_: Option<Path>,
18    not: Vec<String>,
19}
20
21impl Parse for MessageArgs {
22    fn parse(input: ParseStream<'_>) -> Result<Self, ParseError> {
23        let mut args = MessageArgs {
24            ret: None,
25            name: None,
26            part: false,
27            transparent: false,
28            dumping_allowed: true,
29            crate_: None,
30            not: Vec::new(),
31        };
32
33        // `#[message]`
34        // `#[message(name = "N")]`
35        // `#[message(ret = A)]`
36        // `#[message(part)]`
37        // `#[message(part, transparent)]`
38        // `#[message(elfo = some)]`
39        // `#[message(not(Debug))]`
40        // `#[message(dumping = "disabled")]`
41        while !input.is_empty() {
42            let ident: Ident = input.parse()?;
43
44            match ident.to_string().as_str() {
45                "name" => {
46                    let _: Token![=] = input.parse()?;
47                    args.name = Some(input.parse()?);
48                }
49                "ret" => {
50                    let _: Token![=] = input.parse()?;
51                    args.ret = Some(input.parse()?);
52                }
53                "part" => args.part = true,
54                "transparent" => args.transparent = true,
55                "dumping" => {
56                    // TODO: introduce `DumpingMode`.
57                    let _: Token![=] = input.parse()?;
58                    let s: LitStr = input.parse()?;
59
60                    assert_eq!(
61                        s.value(),
62                        "disabled",
63                        "only `dumping = \"disabled\"` is supported"
64                    );
65
66                    args.dumping_allowed = false;
67                }
68                // TODO: call it `crate` like in linkme?
69                "elfo" => {
70                    let _: Token![=] = input.parse()?;
71                    args.crate_ = Some(input.parse()?);
72                }
73                "not" => {
74                    let content;
75                    parenthesized!(content in input);
76                    args.not = content
77                        .parse_terminated::<_, Token![,]>(Ident::parse)?
78                        .iter()
79                        .map(|ident| ident.to_string())
80                        .collect();
81                }
82                attr => panic!("invalid attribute: {attr}"),
83            }
84
85            if !input.is_empty() {
86                let _: Token![,] = input.parse()?;
87            }
88        }
89
90        Ok(args)
91    }
92}
93
94fn gen_derive_attr(blacklist: &[String], name: &str, path: TokenStream2) -> TokenStream2 {
95    let tokens = if blacklist.iter().all(|x| x != name) {
96        quote! { #[derive(#path)] }
97    } else {
98        quote! {}
99    };
100
101    tokens.into_token_stream()
102}
103
104// TODO: add `T: Debug` for type arguments.
105fn gen_impl_debug(input: &DeriveInput) -> TokenStream2 {
106    let name = &input.ident;
107    let field = match &input.data {
108        Data::Struct(data) => {
109            assert_eq!(
110                data.fields.len(),
111                1,
112                "`transparent` is applicable only for structs with one field"
113            );
114            data.fields.iter().next().unwrap()
115        }
116        Data::Enum(_) => panic!("`transparent` is applicable for structs only"),
117        Data::Union(_) => panic!("`transparent` is applicable for structs only"),
118    };
119
120    let propagate_fmt = if let Some(ident) = field.ident.as_ref() {
121        quote! { self.#ident.fmt(f) }
122    } else {
123        quote! { self.0.fmt(f) }
124    };
125
126    quote! {
127        impl ::std::fmt::Debug for #name {
128            #[inline]
129            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
130                #propagate_fmt
131            }
132        }
133    }
134}
135
136pub fn message_impl(
137    args: TokenStream,
138    input: TokenStream,
139    default_path_to_elfo: Path,
140) -> TokenStream {
141    let args = parse_macro_input!(args as MessageArgs);
142    let crate_ = args.crate_.unwrap_or(default_path_to_elfo);
143
144    // TODO: what about parsing into something cheaper?
145    let input = parse_macro_input!(input as DeriveInput);
146    let name = &input.ident;
147    let serde_crate = format!("{}::_priv::serde", crate_.to_token_stream());
148    let internal = quote![#crate_::_priv];
149
150    let protocol = std::env::var("CARGO_PKG_NAME").expect("building without cargo?");
151
152    let name_str = args
153        .name
154        .as_ref()
155        .map(LitStr::value)
156        .unwrap_or_else(|| input.ident.to_string());
157
158    let ret_wrapper = if let Some(ret) = &args.ret {
159        let wrapper_name_str = format!("{name_str}::Response");
160
161        quote! {
162            #[message(not(Debug), name = #wrapper_name_str, elfo = #crate_)]
163            pub struct _elfo_Wrapper(#ret);
164
165            impl fmt::Debug for _elfo_Wrapper {
166                #[inline]
167                fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168                    self.0.fmt(f)
169                }
170            }
171
172            impl From<#ret> for _elfo_Wrapper {
173                #[inline]
174                fn from(inner: #ret) -> Self {
175                    _elfo_Wrapper(inner)
176                }
177            }
178
179            impl From<_elfo_Wrapper> for #ret {
180                #[inline]
181                fn from(wrapper: _elfo_Wrapper) -> Self {
182                    wrapper.0
183                }
184            }
185        }
186    } else {
187        quote! {}
188    };
189
190    let derive_debug = if !args.transparent {
191        gen_derive_attr(&args.not, "Debug", quote![Debug])
192    } else {
193        Default::default()
194    };
195    let derive_clone = gen_derive_attr(&args.not, "Clone", quote![Clone]);
196    let derive_serialize =
197        gen_derive_attr(&args.not, "Serialize", quote![#internal::serde::Serialize]);
198    let derive_deserialize = gen_derive_attr(
199        &args.not,
200        "Deserialize",
201        quote![#internal::serde::Deserialize],
202    );
203
204    let serde_crate_attr = if !derive_serialize.is_empty() || !derive_deserialize.is_empty() {
205        quote! { #[serde(crate = #serde_crate)] }
206    } else {
207        quote! {}
208    };
209
210    let serde_transparent_attr = if args.transparent {
211        quote! { #[serde(transparent)] }
212    } else {
213        quote! {}
214    };
215
216    // TODO: pass to `_elfo_Wrapper`.
217    let dumping_allowed = args.dumping_allowed;
218
219    let impl_message = if !args.part {
220        quote! {
221            impl #crate_::Message for #name {
222                const VTABLE: &'static #internal::MessageVTable = VTABLE;
223
224                #[inline(always)]
225                fn _touch(&self) {
226                    touch();
227                }
228            }
229
230            #ret_wrapper
231
232            fn cast_ref(message: &#internal::AnyMessage) -> &#name {
233                message.downcast_ref::<#name>().expect("invalid vtable")
234            }
235
236            fn clone(message: &#internal::AnyMessage) -> #internal::AnyMessage {
237                #internal::AnyMessage::new(Clone::clone(cast_ref(message)))
238            }
239
240            fn debug(message: &#internal::AnyMessage, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241                fmt::Debug::fmt(cast_ref(message), f)
242            }
243
244            fn erase(message: &#internal::AnyMessage) -> #crate_::dumping::ErasedMessage {
245                smallbox!(Clone::clone(cast_ref(message)))
246            }
247
248            const VTABLE: &'static #internal::MessageVTable = &#internal::MessageVTable {
249                name: #name_str,
250                protocol: #protocol,
251                labels: &[
252                    metrics::Label::from_static_parts("message", #name_str),
253                    metrics::Label::from_static_parts("protocol", #protocol),
254                ],
255                dumping_allowed: #dumping_allowed,
256                clone,
257                debug,
258                erase,
259            };
260
261            #[linkme::distributed_slice(MESSAGE_LIST)]
262            #[linkme(crate = linkme)]
263            static VTABLE_STATIC: &'static #internal::MessageVTable = <#name as #crate_::Message>::VTABLE;
264
265            // See [rust#47384](https://github.com/rust-lang/rust/issues/47384).
266            #[doc(hidden)]
267            #[inline(never)]
268            pub fn touch() {}
269        }
270    } else {
271        quote! {}
272    };
273
274    let impl_request = if let Some(ret) = &args.ret {
275        assert!(!args.part, "`part` and `ret` attributes are incompatible");
276
277        quote! {
278            impl #crate_::Request for #name {
279                type Response = #ret;
280                type Wrapper = _elfo_Wrapper;
281            }
282        }
283    } else {
284        quote! {}
285    };
286
287    let impl_debug = if args.transparent && args.not.iter().all(|x| x != "Debug") {
288        gen_impl_debug(&input)
289    } else {
290        quote! {}
291    };
292
293    TokenStream::from(quote! {
294        #derive_debug
295        #derive_clone
296        #derive_serialize
297        #derive_deserialize
298        #serde_crate_attr
299        #serde_transparent_attr
300        #input
301
302        #[doc(hidden)]
303        #[allow(non_snake_case)]
304        const _: () = {
305            // Keep this list as minimal as possible to avoid possible collisions with `#name`.
306            // Especially avoid `PascalCase`.
307            use ::std::fmt;
308            use #internal::{MESSAGE_LIST, smallbox::smallbox, linkme, metrics};
309
310            #impl_message
311            #impl_request
312            #impl_debug
313        };
314    })
315}